Support for Flash Attention 3 (#38972)

* Support `flash_attn_3`
Implements the forward pass and tests for Flash Attention 3 (https://github.com/Dao-AILab/flash-attention/commits/main/hopper)

- Includes checks for dropout>0 and ALiBi in `modeling_utils.PreTrainedModel._check_and_enable_flash_attn_3` (dropout will likely be supported soon, so this check, together with the `if _IS_FLASH_ATTN_3_AVAILABLE: ...` branch in `modeling_flash_attention_utils._flash_attention_forward`, will need to be updated)

An example Llama implementation is included in `modeling_llama.py`, but other models still need to be updated

Based on https://github.com/huggingface/transformers/pull/36190, which has model implementations and examples that could be merged
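For reference, a minimal usage sketch (assuming the `flash_attn_3` package is installed and a Hopper-class GPU with compute capability >= 9.0 is available; the model and prompt are the ones used in the parity test below):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Requires the flash_attn_3 package and a float16/bfloat16 dtype; otherwise the
# checks in _check_and_enable_flash_attn_3 raise or warn.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B-Instruct",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_3",
).to("cuda")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")

inputs = tokenizer("The ETH AI Center is", return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_new_tokens=20, do_sample=False)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```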

* Add tests for Flash Attention 2 and 3 parity

* ci fix

* FA2 compatibility
- `_prepare_flash_attention_from_position_ids` -> `prepare_fa2_from_position_ids`
- Remove bettertransformer check in Flash Attention 3
- Merge tests
- Add licensing

* ci fix

* Test naming consistency

* ci fix

* Deprecation warning for `prepare_fa2_from_position_ids`

* ci fix
EduardDurech 2025-06-25 14:39:27 +02:00 committed by GitHub
parent de98fb25a3
commit a2eb75c891
42 changed files with 698 additions and 262 deletions

View File

@@ -52,6 +52,7 @@ line-ending = "auto"
 addopts = "--doctest-glob='**/*.md'"
 doctest_optionflags="NUMBER NORMALIZE_WHITESPACE ELLIPSIS"
 markers = [
+    "flash_attn_3_test: marks tests related to flash attention 3 (deselect with '-m \"not flash_attn_3_test\"')",
     "flash_attn_test: marks tests related to flash attention (deselect with '-m \"not flash_attn_test\"')",
     "bitsandbytes: select (or deselect with `not`) bitsandbytes integration tests",
     "generate: marks tests that use the GenerationTesterMixin"

View File

@@ -75,6 +75,7 @@ def flash_attention_forward(
         softcap=softcap,
         use_top_left_mask=_use_top_left_mask,
         target_dtype=target_dtype,
+        attn_implementation=module.config._attn_implementation,
         **kwargs,
     )

View File

@@ -14,6 +14,7 @@
 import inspect
 import os
+import warnings
 from typing import Optional, TypedDict

 import torch
@@ -21,6 +22,7 @@ import torch.nn.functional as F
 from .utils import (
     is_flash_attn_2_available,
+    is_flash_attn_3_available,
     is_flash_attn_greater_or_equal,
     is_flash_attn_greater_or_equal_2_10,
     is_torch_npu_available,
@@ -32,18 +34,123 @@ logger = logging.get_logger(__name__)
 flash_attn_func = None
-if is_flash_attn_2_available():
-    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
-    from flash_attn import flash_attn_func, flash_attn_varlen_func
-    from flash_attn.layers.rotary import apply_rotary_emb  # noqa
+
+
+def _index_first_axis(tensor, indices):
+    """
+    A local implementation of the PyTorch indexing operation `tensor[indices]` on the first axis,
+    after flattening the first two dimensions of the tensor. This is functionally equivalent to
+    FA2's `index_first_axis` and replaces the need to import it.
+    """
+    # The input tensor is expected to be of shape (batch, seq_len, ...). We flatten the first
+    # two dimensions to get (total_tokens, ...) before indexing.
+    reshaped_tensor = tensor.reshape(-1, *tensor.shape[2:])
+    return reshaped_tensor[indices]
+
+
+def _fa3_unpad_input(hidden_states, attention_mask, unused_mask=None):
+    """
+    FA3-compatible unpad_input function.
+
+    Arguments:
+        hidden_states: (batch, seqlen, ...)
+        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
+        unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.
+    Return:
+        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
+        indices: (total_nnz), the indices of masked tokens from the flattened input sequence.
+        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
+        max_seqlen_in_batch: int
+        seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask.
+    """
+    all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask
+    seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
+    used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+    indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
+    max_seqlen_in_batch = seqlens_in_batch.max().item()
+    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+    return (
+        _index_first_axis(hidden_states, indices),
+        indices,
+        cu_seqlens,
+        max_seqlen_in_batch,
+        used_seqlens_in_batch,
+    )
+
+
+def _fa3_pad_input(hidden_states, indices, batch, seqlen):
+    """
+    FA3-compatible pad_input function.
+
+    Arguments:
+        hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
+        indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
+        batch: int, batch size for the padded sequence.
+        seqlen: int, maximum sequence length for the padded sequence.
+    Return:
+        hidden_states: (batch, seqlen, ...)
+    """
+    dim = hidden_states.shape[1:]
+    output = torch.zeros((batch * seqlen), *dim, device=hidden_states.device, dtype=hidden_states.dtype)
+    output[indices] = hidden_states
+    return output.view(batch, seqlen, *dim)
+
+
+FA_VERSION = None
+if is_flash_attn_2_available():
+    from flash_attn import flash_attn_func as flash_attn_2_func
+    from flash_attn import flash_attn_varlen_func as flash_attn_2_varlen_func
+    from flash_attn.bert_padding import pad_input as pad_input_fa2
+    from flash_attn.bert_padding import unpad_input as unpad_input_fa2
+    from flash_attn.layers.rotary import apply_rotary_emb
+
+    HAS_FA2 = True
+    FA_VERSION = 2
+else:
+    flash_attn_2_func = None
+    flash_attn_2_varlen_func = None
+    pad_input_fa2 = None
+    unpad_input_fa2 = None
+    apply_rotary_emb = None
+    HAS_FA2 = False
+
+if is_flash_attn_3_available():
+    from flash_attn_interface import flash_attn_func as flash_attn_3_func
+    from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
+
+    pad_input_fa3 = _fa3_pad_input
+    unpad_input_fa3 = _fa3_unpad_input
+    HAS_FA3 = True
+    FA_VERSION = 3
+else:
+    flash_attn_3_func = None
+    flash_attn_3_varlen_func = None
+    pad_input_fa3 = None
+    unpad_input_fa3 = None
+    HAS_FA3 = False
+
+
+# Current Flash Attention implementations
+if FA_VERSION:
+    flash_attn_func = globals()[f"flash_attn_{FA_VERSION}_func"]
+    flash_attn_varlen_func = globals()[f"flash_attn_{FA_VERSION}_varlen_func"]
+    unpad_input = globals()[f"unpad_input_fa{FA_VERSION}"]
+    pad_input = globals()[f"pad_input_fa{FA_VERSION}"]

 # patch functions in package `flash-attn` when using flash-attention on Ascend NPU.
 if is_torch_npu_available():
-    from .integrations.npu_flash_attention import index_first_axis, pad_input, unpad_input
-    from .integrations.npu_flash_attention import npu_apply_rotary_emb as apply_rotary_emb  # noqa
-    from .integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_func
-    from .integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_varlen_func
+    from .integrations.npu_flash_attention import (
+        npu_apply_rotary_emb as apply_rotary_emb,  # noqa: F401
+    )
+    from .integrations.npu_flash_attention import (
+        npu_flash_attn_func as flash_attn_func,
+    )
+    from .integrations.npu_flash_attention import (
+        npu_flash_attn_varlen_func as flash_attn_varlen_func,
+    )
+    from .integrations.npu_flash_attention import (
+        pad_input,
+        unpad_input,
+    )

 _flash_supports_window_size = False
@@ -56,6 +163,9 @@ if flash_attn_func:
 def is_flash_attn_available():
     """Determine whether flash-attention can be used or not."""

+    if is_flash_attn_3_available():
+        return True
+
     # if package `flash-attn` is available, flash-attention can be used natively.
     if is_flash_attn_2_available():
         return True
@@ -70,6 +180,9 @@ def is_flash_attn_available():
 def flash_attn_supports_top_left_mask():
     """Determine whether flash-attention uses top-left or down-right mask"""

+    if is_flash_attn_3_available():
+        return False
+
     if is_flash_attn_2_available():
         # top-left mask is used in package `flash-attn` with version lower than 2.1.0
         return not is_flash_attn_greater_or_equal_2_10()
@@ -116,6 +229,7 @@ def _upad_input(
     value_layer: torch.Tensor,
     attention_mask: torch.Tensor,
     query_length: int,
+    unpad_input_func,
 ):
     """
     Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches.
@@ -134,6 +248,8 @@ def _upad_input(
            Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
        query_length (`int`):
            Target length.
+        unpad_input_func:
+            The function to use for unpadding the input tensors.

     Return:
        query_layer (`torch.Tensor`):
@@ -158,12 +274,10 @@
     batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

-    key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k)
-    value_layer = index_first_axis(
-        value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
-    )
+    key_layer = _index_first_axis(key_layer, indices_k)
+    value_layer = _index_first_axis(value_layer, indices_k)
     if query_length == kv_seq_len:
-        query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, -1, head_dim), indices_k)
+        query_layer = _index_first_axis(query_layer, indices_k)
         cu_seqlens_q = cu_seqlens_k
         max_seqlen_in_batch_q = max_seqlen_in_batch_k
         indices_q = indices_k
@@ -177,7 +291,7 @@
     else:
         # The -q_len: slice assumes left padding.
         attention_mask = attention_mask[:, -query_length:]
-        query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q, *_ = unpad_input(query_layer, attention_mask)
+        query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q, *_ = unpad_input_func(query_layer, attention_mask)

     return (
         query_layer,
@@ -189,7 +303,7 @@
     )


-def prepare_fa2_from_position_ids(query, key, value, position_ids):
+def _prepare_flash_attention_from_position_ids(query, key, value, position_ids):
     """
     This function returns necessary arguments to call `flash_attn_varlen_func`.
     All three query, key, value states will be flattened.
@@ -239,6 +353,14 @@ def prepare_fa2_from_position_ids(query, key, value, position_ids):
     return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length))


+def prepare_fa2_from_position_ids(*args, **kwargs):
+    warnings.warn(
+        "The function `prepare_fa2_from_position_ids` in `transformers.modeling_flash_attention_utils` is deprecated and will be removed in a future version. Please use `_prepare_flash_attention_from_position_ids` instead.",
+        FutureWarning,
+    )
+    return _prepare_flash_attention_from_position_ids(*args, **kwargs)
+
+
 def fa_peft_integration_check(
     query: torch.Tensor,
     key: torch.Tensor,
@@ -303,6 +425,7 @@ def _flash_attention_forward(
     max_length_q: Optional[int] = None,
     max_length_k: Optional[int] = None,
     target_dtype: Optional[torch.dtype] = None,
+    attn_implementation: Optional[str] = None,
     **kwargs,
 ):
     """
@@ -329,7 +452,28 @@
            Softcap for the attention logits, used e.g. in gemma2.
        deterministic (`bool`, *optional*):
            Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled.
+        attn_implementation (`str`, *optional*):
+            The attention implementation to use. If None, will default to the one based on the environment.
     """
+    if attn_implementation is None:
+        _flash_attn_varlen_func = flash_attn_varlen_func
+        _flash_attn_func = flash_attn_func
+        _pad_input = pad_input
+        _unpad_input = unpad_input
+        _is_fa3 = HAS_FA3
+    elif attn_implementation == "flash_attention_3":
+        _flash_attn_varlen_func = flash_attn_3_varlen_func
+        _flash_attn_func = flash_attn_3_func
+        _pad_input = pad_input_fa3
+        _unpad_input = unpad_input_fa3
+        _is_fa3 = True
+    elif attn_implementation == "flash_attention_2":
+        _flash_attn_varlen_func = flash_attn_2_varlen_func
+        _flash_attn_func = flash_attn_2_func
+        _pad_input = pad_input_fa2
+        _unpad_input = unpad_input_fa2
+        _is_fa3 = False
+
     if not use_top_left_mask:
         causal = is_causal
     else:
@@ -342,6 +486,12 @@
         )

     flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}
+
+    if _is_fa3:
+        if dropout > 0.0:
+            logger.warning_once("Flash Attention 3 does not support dropout. Setting dropout to 0.0.")
+    else:
+        flash_kwargs["dropout_p"] = dropout
+
     if flash_241:
         if deterministic is None:
             global deterministic_g
@@ -362,12 +512,12 @@
     if attention_mask is not None:
         batch_size = query_states.shape[0]
         query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = _upad_input(
-            query_states, key_states, value_states, attention_mask, query_length
+            query_states, key_states, value_states, attention_mask, query_length, _unpad_input
         )
         cu_seqlens_q, cu_seqlens_k = cu_seq_lens
         max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

-        attn_output_unpad = flash_attn_varlen_func(
+        attn_output_unpad = _flash_attn_varlen_func(
             query_states,
             key_states,
             value_states,
@@ -375,12 +525,11 @@
             cu_seqlens_k=cu_seqlens_k,
             max_seqlen_q=max_seqlen_in_batch_q,
             max_seqlen_k=max_seqlen_in_batch_k,
-            dropout_p=dropout,
             softmax_scale=softmax_scale,
             causal=causal,
             **flash_kwargs,
         )
-        attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
+        attn_output = _pad_input(attn_output_unpad, indices_q, batch_size, query_length)

     # If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing
     # then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage.
@@ -394,7 +543,7 @@
         if cu_seq_lens_q is None or cu_seq_lens_k is None:
             query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = (
-                prepare_fa2_from_position_ids(query_states, key_states, value_states, position_ids)
+                _prepare_flash_attention_from_position_ids(query_states, key_states, value_states, position_ids)
             )

             cu_seq_lens_q, cu_seq_lens_k = cu_seq_lens
@@ -405,7 +554,7 @@
         key_states = key_states.reshape(-1, key_states.size(-2), key_states.size(-1))
         value_states = value_states.reshape(-1, value_states.size(-2), value_states.size(-1))

-        attn_output = flash_attn_varlen_func(
+        attn_output = _flash_attn_varlen_func(
             query_states,
             key_states,
             value_states,
@@ -413,7 +562,6 @@
             cu_seqlens_k=cu_seq_lens_k,
             max_seqlen_q=max_length_q,
             max_seqlen_k=max_length_k,
-            dropout_p=dropout,
             softmax_scale=softmax_scale,
             causal=causal,
             **flash_kwargs,
@@ -422,10 +570,12 @@
         attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))

     else:
-        attn_output = flash_attn_func(
-            query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal, **flash_kwargs
+        attn_output = _flash_attn_func(
+            query_states, key_states, value_states, softmax_scale=softmax_scale, causal=causal, **flash_kwargs
         )

+    if isinstance(attn_output, tuple):
+        return attn_output[0]
     return attn_output
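To illustrate what the padding helpers introduced above do, here is a small self-contained sketch of the unpad/pad round-trip; the logic is copied inline from `_fa3_unpad_input`/`_fa3_pad_input` for illustration, so nothing private is imported:

```python
import torch
import torch.nn.functional as F

batch, seqlen, dim = 2, 4, 8
hidden_states = torch.randn(batch, seqlen, dim)
# 1 = real token, 0 = padding (the second sequence has one padded position)
attention_mask = torch.tensor([[1, 1, 1, 1], [1, 1, 1, 0]], dtype=torch.int32)

# unpad: keep only the valid tokens, flattened to (total_nnz, dim)
seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
unpadded = hidden_states.reshape(-1, dim)[indices]
print(unpadded.shape, cu_seqlens)  # torch.Size([7, 8]) tensor([0, 4, 7], dtype=torch.int32)

# pad: scatter the tokens back into a zero-initialized (batch, seqlen, dim) tensor
output = torch.zeros(batch * seqlen, dim, dtype=hidden_states.dtype)
output[indices] = unpadded
repadded = output.view(batch, seqlen, dim)
assert torch.equal(repadded[attention_mask.bool()], hidden_states[attention_mask.bool()])
```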

View File

@@ -105,6 +105,7 @@ from .utils import (
     is_accelerate_available,
     is_bitsandbytes_available,
     is_flash_attn_2_available,
+    is_flash_attn_3_available,
     is_kernels_available,
     is_offline_mode,
     is_optimum_available,
@@ -1957,6 +1958,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
     # Flash Attention 2 support
     _supports_flash_attn_2 = False

+    # Flash Attention 3 support
+    _supports_flash_attn_3 = False
+
     # SDPA support
     _supports_sdpa = False
@@ -2247,6 +2251,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
                 and config._attn_implementation not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys()
             ):
                 message = f'Specified `attn_implementation="{config._attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)'
+                if cls._supports_flash_attn_3:
+                    message += ', `"attn_implementation=flash_attention_3"` (implementation using flash attention 3)'
                 if cls._supports_flash_attn_2:
                     message += ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)'
                 if cls._supports_sdpa:
@@ -2282,7 +2288,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
             ):
                 sub_config._attn_implementation_internal = curr_attn_implementation

-        if config._attn_implementation == "flash_attention_2":
+        if config._attn_implementation == "flash_attention_3":
+            cls._check_and_enable_flash_attn_3(
+                config,
+                torch_dtype=torch_dtype,
+                device_map=device_map,
+                hard_check_only=False,
+                check_device_map=check_device_map,
+            )
+        elif config._attn_implementation == "flash_attention_2":
             cls._check_and_enable_flash_attn_2(
                 config,
                 torch_dtype=torch_dtype,
@@ -2498,6 +2512,94 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
             config._attn_implementation = "flash_attention_2"
         return config

+    @classmethod
+    def _check_and_enable_flash_attn_3(
+        cls,
+        config,
+        torch_dtype: Optional[torch.dtype] = None,
+        device_map: Optional[Union[str, dict[str, int]]] = None,
+        check_device_map: bool = True,
+        hard_check_only: bool = False,
+    ) -> PretrainedConfig:
+        """
+        Checks the availability of Flash Attention 3 and compatibility with the current model.
+
+        If all checks pass and `hard_check_only` is False, the method will set the config attribute `attn_implementation` to "flash_attention_3" so that the model can initialize the correct attention module.
+        """
+        if not cls._supports_flash_attn_3:
+            raise ValueError(
+                f"{cls.__name__} does not support Flash Attention 3.0 yet. Please request to add support where"
+                f" the model is hosted, on its model hub page: https://huggingface.co/{config._name_or_path}/discussions/new"
+                " or in the Transformers GitHub repo: https://github.com/huggingface/transformers/issues/new"
+            )
+
+        if not is_flash_attn_3_available():
+            preface = "FlashAttention3 has been toggled on, but it cannot be used due to the following error:"
+
+            if importlib.util.find_spec("flash_attn_3") is None:
+                raise ImportError(f"{preface} the package flash_attn_3 seems to be not installed.")
+
+            if torch.cuda.is_available():
+                major, _ = torch.cuda.get_device_capability()
+                if major < 9:
+                    raise ValueError(
+                        f"{preface} Flash Attention 3 requires compute capability >= 9.0, but found {torch.cuda.get_device_capability()} with compute capability {major}.0."
+                    )
+                else:
+                    raise ImportError(f"{preface} Flash Attention 3 is not available.")
+            else:
+                raise ValueError(
+                    f"{preface} Flash Attention 3 is not available on CPU. Please make sure torch can access a CUDA device."
+                )
+
+        if torch_dtype is None:
+            logger.warning_once(
+                "You are attempting to use Flash Attention 3 without specifying a torch dtype. This might lead to unexpected behaviour"
+            )
+        elif torch_dtype is not None and torch_dtype not in [torch.float16, torch.bfloat16]:
+            logger.warning_once(
+                "Flash Attention 3 only supports torch.float16 and torch.bfloat16 dtypes, but"
+                f" the current dype in {cls.__name__} is {torch_dtype}. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator,"
+                ' or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("meta-llama/Llama-3.2-1B", attn_implementation="flash_attention_3", torch_dtype=torch.float16)`'
+            )
+
+        if getattr(config, "alibi", False) or getattr(config, "use_alibi", False):
+            raise ValueError("Model is configured to use ALiBi, which is not supported by Flash Attention 3.")
+
+        # Check for attention dropout, which is incompatible with FA3
+        if hasattr(config, "attention_dropout") and config.attention_dropout > 0:
+            raise ValueError(
+                f"Model has attention_dropout={config.attention_dropout}, which is not supported by Flash Attention 3."
+            )
+
+        # The check `torch.empty(0).device.type != "cuda"` is needed as the model may be initialized after `torch.set_default_device` has been called,
+        # or the model may be initialized under the context manager `with torch.device("cuda"):`.
+        if check_device_map and device_map is None and torch.empty(0).device.type not in ["cuda", "mlu"]:
+            if torch.cuda.is_available():
+                logger.warning_once(
+                    "You are attempting to use Flash Attention 3 with a model not initialized on GPU. Make sure to move the model to GPU"
+                    " after initializing it on CPU with `model.to('cuda')`."
+                )
+            else:
+                raise ValueError(
+                    "You are attempting to use Flash Attention 3 with a model not initialized on GPU and with no GPU available. "
+                    "This is not supported yet. Please make sure to have access to a GPU and either initialise the model on a GPU by passing a device_map "
+                    "or initialising the model on CPU and then moving it to GPU."
+                )
+        elif (
+            check_device_map
+            and device_map is not None
+            and isinstance(device_map, dict)
+            and ("cpu" in device_map.values() or "disk" in device_map.values())
+        ):
+            raise ValueError(
+                "You are attempting to use Flash Attention 3 with a model dispatched on CPU or disk. This is not supported. Please make sure to "
+                "initialise the model on a GPU by passing a device_map that contains only GPU devices as keys."
+            )
+        if not hard_check_only:
+            config._attn_implementation = "flash_attention_3"
+        return config
+
     @classmethod
     def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> PretrainedConfig:
         """
@@ -4134,7 +4236,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
             </Tip>

             attn_implementation (`str`, *optional*):
-                The attention implementation to use in the model (if relevant). Can be any of `"eager"` (manual implementation of the attention), `"sdpa"` (using [`F.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)), or `"flash_attention_2"` (using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation.
+                The attention implementation to use in the model (if relevant). Can be any of `"eager"` (manual implementation of the attention), `"sdpa"` (using [`F.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)), `"flash_attention_2"` (using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)), or `"flash_attention_3"` (using [Dao-AILab/flash-attention/hopper](https://github.com/Dao-AILab/flash-attention/tree/main/hopper)). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation.

             > Parameters for big model inference
@@ -5770,6 +5872,7 @@ class AttentionInterface(GeneralInterface):
     # Class instance object, so that a call to `register` can be reflected into all other files correctly, even if
     # a new instance is created (in order to locally override a given function)
     _global_mapping = {
+        "flash_attention_3": flash_attention_forward,
         "flash_attention_2": flash_attention_forward,
         "flex_attention": flex_attention_forward,
         "paged_attention": paged_attention_forward,

View File

@@ -321,6 +321,7 @@ class ArceePreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["ArceeDecoderLayer"]
     _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn_3 = True
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_flex_attn = True
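The per-model opt-in shown here is repeated across the model files below; as a sketch, a hypothetical model class (names are illustrative, not part of this diff) would simply declare the new flag alongside the existing ones:

```python
from transformers import PretrainedConfig
from transformers.modeling_utils import PreTrainedModel


class MyModelConfig(PretrainedConfig):  # hypothetical config, for illustration only
    model_type = "my_model"


class MyModelPreTrainedModel(PreTrainedModel):  # hypothetical model class
    config_class = MyModelConfig
    supports_gradient_checkpointing = True
    _no_split_modules = ["MyModelDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_3 = True  # opt-in flag checked by _check_and_enable_flash_attn_3
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True
```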

View File

@@ -667,6 +667,7 @@ class AriaPreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["AriaDecoderLayer"]
     _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn_3 = True
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_flex_attn = True

View File

@@ -318,6 +318,7 @@ class BitNetPreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["BitNetDecoderLayer"]
     _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn_3 = True
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_flex_attn = True

View File

@@ -355,6 +355,7 @@ class CoherePreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["CohereDecoderLayer"]
     _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn_3 = True
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_flex_attn = True

View File

@@ -334,6 +334,7 @@ class Cohere2PreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["Cohere2DecoderLayer"]
     _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn_3 = True
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_flex_attn = True

View File

@@ -504,6 +504,7 @@ class DeepseekV3PreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["DeepseekV3DecoderLayer"]
     _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn_3 = True
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_flex_attn = True

View File

@@ -556,6 +556,7 @@ class DiffLlamaPreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["DiffLlamaDecoderLayer"]
     _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn_3 = True
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_flex_attn = False

View File

@@ -424,6 +424,7 @@ class Dots1PreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["Dots1DecoderLayer"]
     _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn_3 = True
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_flex_attn = True

View File

@@ -318,6 +318,7 @@ class GemmaPreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["GemmaDecoderLayer"]
     _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn_3 = True
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_flex_attn = True

View File

@@ -339,6 +339,7 @@ class Gemma2PreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["Gemma2DecoderLayer"]
     _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn_3 = True
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_flex_attn = True

View File

@@ -422,6 +422,7 @@ class Gemma3PreTrainedModel(PreTrainedModel):
         "SiglipMultiheadAttentionPoolingHead",
     ]
     _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn_3 = True
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_flex_attn = True

View File

@@ -335,6 +335,7 @@ class GlmPreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["GlmDecoderLayer"]
     _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn_3 = True
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_flex_attn = True

View File

@@ -343,6 +343,7 @@ class Glm4PreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["Glm4DecoderLayer"]
     _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn_3 = True
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_flex_attn = True

View File

@@ -292,6 +292,7 @@ class GPTNeoXPreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["GPTNeoXLayer"]
     _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn_3 = True
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_flex_attn = True

View File

@@ -305,6 +305,7 @@ class GranitePreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["GraniteDecoderLayer"]
     _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn_3 = True
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_flex_attn = True

View File

@@ -320,6 +320,7 @@ class HeliumPreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["HeliumDecoderLayer"]
     _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn_3 = True
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_flex_attn = True

View File

@@ -320,6 +320,7 @@ class LlamaPreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["LlamaDecoderLayer"]
     _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn_3 = True
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_flex_attn = True

View File

@@ -590,6 +590,7 @@ class MiniMaxPreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["MiniMaxDecoderLayer"]
     _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn_3 = True
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_flex_attn = True

View File

@@ -262,6 +262,7 @@ class MistralPreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["MistralDecoderLayer"]
     _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn_3 = True
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_flex_attn = True

View File

@@ -417,6 +417,7 @@ class MixtralPreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["MixtralDecoderLayer"]
     _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn_3 = True
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_flex_attn = True

View File

@@ -301,6 +301,7 @@ class OlmoPreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["OlmoDecoderLayer"]
     _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn_3 = True
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_flex_attn = True

View File

@@ -305,6 +305,7 @@ class Olmo2PreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["Olmo2DecoderLayer"]
     _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn_3 = True
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_flex_attn = True

View File

@@ -295,6 +295,7 @@ class PhiPreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["PhiDecoderLayer"]
     _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn_3 = True
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_flex_attn = True

View File

@@ -316,6 +316,7 @@ class Phi3PreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["Phi3DecoderLayer"]
     _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn_3 = True
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_flex_attn = True

View File

@@ -1622,6 +1622,7 @@ class Phi4MultimodalPreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["Phi4MultimodalDecoderLayer"]
     _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn_3 = True
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_flex_attn = True

View File

@@ -266,6 +266,7 @@ class Qwen2PreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["Qwen2DecoderLayer"]
     _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn_3 = True
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_flex_attn = True

View File

@@ -292,6 +292,7 @@ class Qwen3PreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["Qwen3DecoderLayer"]
     _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn_3 = True
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_flex_attn = True

View File

@@ -424,6 +424,7 @@ class Qwen3MoePreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["Qwen3MoeDecoderLayer"]
     _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn_3 = True
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_flex_attn = True

View File

@@ -299,6 +299,7 @@ class Starcoder2PreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["Starcoder2DecoderLayer"]
     _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn_3 = True
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_flex_attn = True

View File

@@ -561,6 +561,7 @@ class T5GemmaPreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["T5GemmaBlock"]
     _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn_3 = True
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_flex_attn = True

View File

@@ -86,6 +86,7 @@ from .utils import (
     is_faiss_available,
     is_fbgemm_gpu_available,
     is_flash_attn_2_available,
+    is_flash_attn_3_available,
     is_flax_available,
     is_flute_available,
     is_fsdp_available,
@@ -571,6 +572,15 @@ def require_flash_attn(test_case):
     return unittest.skipUnless(is_flash_attn_2_available(), "test requires Flash Attention")(test_case)


+def require_flash_attn_3(test_case):
+    """
+    Decorator marking a test that requires Flash Attention 3.
+
+    These tests are skipped when Flash Attention 3 isn't installed.
+    """
+    return unittest.skipUnless(is_flash_attn_3_available(), "test requires Flash Attention 3")(test_case)
+
+
 def require_torch_sdpa(test_case):
     """
     Decorator marking a test that requires PyTorch's SDPA.
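For tests outside this diff, the new decorator and pytest marker would be combined roughly like this (the test class and body are illustrative only, not part of the change):

```python
import unittest

import pytest

from transformers.testing_utils import require_flash_attn_3, require_torch_gpu, slow


class MyFlashAttention3Tests(unittest.TestCase):  # hypothetical test class
    @pytest.mark.flash_attn_3_test  # marker registered in pyproject.toml; deselect with -m "not flash_attn_3_test"
    @require_flash_attn_3  # skips when flash_attn_3 is not installed
    @require_torch_gpu
    @slow
    def test_something_with_fa3(self):
        self.assertTrue(True)  # placeholder body
```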

View File

@@ -153,6 +153,7 @@ from .import_utils import (
     is_faiss_available,
     is_fbgemm_gpu_available,
     is_flash_attn_2_available,
+    is_flash_attn_3_available,
     is_flash_attn_greater_or_equal,
     is_flash_attn_greater_or_equal_2_10,
     is_flax_available,

View File

@@ -926,6 +926,9 @@ class ClassAttrs:
     _skip_keys_device_placement = r"""
        A list of keys to ignore when moving inputs or outputs between devices when using the `accelerate` library.
    """
+    _supports_flash_attn_3 = r"""
+        Whether the model's attention implementation supports FlashAttention 3.0.
+    """
     _supports_flash_attn_2 = r"""
        Whether the model's attention implementation supports FlashAttention 2.0.
    """

View File

@@ -1120,6 +1120,25 @@ def is_flash_attn_2_available():
     return False


+@lru_cache()
+def is_flash_attn_3_available():
+    if not is_torch_available():
+        return False
+
+    if not _is_package_available("flash_attn_3"):
+        return False
+
+    import torch
+
+    if not torch.cuda.is_available():
+        return False
+
+    # TODO: Check for a minimum version when FA3 is stable
+    # return version.parse(importlib.metadata.version("flash_attn_3")) >= version.parse("3.0.0")
+
+    return True
+
+
 @lru_cache
 def is_flash_attn_greater_or_equal_2_10():
     if not _is_package_available("flash_attn"):
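Once this lands, the new check can be used to pick an implementation at load time; a small sketch (the fallback order here is an assumption, not something the diff prescribes):

```python
import torch
from transformers import AutoModelForCausalLM
from transformers.utils import is_flash_attn_2_available, is_flash_attn_3_available

# Prefer FA3 when the flash_attn_3 package and a CUDA device are available,
# otherwise fall back to FA2 or SDPA.
if is_flash_attn_3_available():
    attn_implementation = "flash_attention_3"
elif is_flash_attn_2_available():
    attn_implementation = "flash_attention_2"
else:
    attn_implementation = "sdpa"

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    torch_dtype=torch.bfloat16,
    attn_implementation=attn_implementation,
)
```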

View File

@@ -0,0 +1,144 @@
+# Copyright 2025 Eduard Durech and SGLang team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Usage:
+#   RUN_SLOW=1 pytest -s tests/generation/test_flash_attention_parity.py
+import unittest
+
+import pytest
+import torch
+
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers.testing_utils import require_flash_attn, require_flash_attn_3, require_torch_gpu, slow
+
+
+class FlashAttentionParityTest(unittest.TestCase):
+    # From https://github.com/sgl-project/sglang/blob/main/python/sglang/test/test_utils.py
+    def _lcs(self, X, Y):
+        m = len(X)
+        n = len(Y)
+        L = [[0] * (n + 1) for _ in range(m + 1)]
+
+        for i in range(m + 1):
+            for j in range(n + 1):
+                if i == 0 or j == 0:
+                    L[i][j] = 0
+                elif X[i - 1] == Y[j - 1]:
+                    L[i][j] = L[i - 1][j - 1] + 1
+                else:
+                    L[i][j] = max(L[i - 1][j], L[i][j - 1])
+
+        return L[m][n]
+
+    # From https://github.com/sgl-project/sglang/blob/main/python/sglang/test/test_utils.py
+    def _calculate_rouge_l(self, output_strs_list1, output_strs_list2):
+        rouge_l_scores = []
+
+        for s1, s2 in zip(output_strs_list1, output_strs_list2):
+            lcs_len = self._lcs(s1, s2)
+            precision = lcs_len / len(s1) if len(s1) > 0 else 0
+            recall = lcs_len / len(s2) if len(s2) > 0 else 0
+            if precision + recall > 0:
+                fmeasure = (2 * precision * recall) / (precision + recall)
+            else:
+                fmeasure = 0.0
+            rouge_l_scores.append(fmeasure)
+
+        return rouge_l_scores
+
+    def _benchmark_generation(self, model, inputs, n_warmup=3, n_runs=5):
+        for _ in range(n_warmup):
+            model.generate(**inputs, max_new_tokens=20, do_sample=False)
+        torch.cuda.synchronize()
+
+        start_time = torch.cuda.Event(enable_timing=True)
+        end_time = torch.cuda.Event(enable_timing=True)
+        start_time.record()
+        for _ in range(n_runs):
+            model.generate(**inputs, max_new_tokens=20, do_sample=False)
+        end_time.record()
+        torch.cuda.synchronize()
+
+        return start_time.elapsed_time(end_time) / n_runs
+
+    @pytest.mark.flash_attn_3_test
+    @require_torch_gpu
+    @require_flash_attn
+    @require_flash_attn_3
+    @slow
+    def test_flash_attention_2_3_parity(self):
+        model_id = "meta-llama/Llama-3.2-1B-Instruct"
+        prompt = "The ETH AI Center is"
+
+        # 1. Load FA2 model and tokenizer
+        model_2 = AutoModelForCausalLM.from_pretrained(
+            model_id,
+            torch_dtype=torch.bfloat16,
+            attn_implementation="flash_attention_2",
+        ).to("cuda")
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+        # 2. Load FA3 model
+        try:
+            model_3 = AutoModelForCausalLM.from_pretrained(
+                model_id,
+                torch_dtype=torch.bfloat16,
+                attn_implementation="flash_attention_3",
+            ).to("cuda")
+        except (ValueError, ImportError) as e:
+            pytest.skip(f"Could not load Flash Attention 3 model, skipping test. Error: {e}")
+
+        # 3. Generate with both models
+        inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
+        with torch.no_grad():
+            output_2 = model_2.generate(
+                **inputs, max_new_tokens=20, do_sample=False, output_scores=True, return_dict_in_generate=True
+            )
+            output_3 = model_3.generate(
+                **inputs, max_new_tokens=20, do_sample=False, output_scores=True, return_dict_in_generate=True
+            )
+
+        # 4. Correctness check
+        # 4a. Logits
+        logits_2 = torch.stack(output_2.scores)
+        logits_3 = torch.stack(output_3.scores)
+        torch.testing.assert_close(logits_2, logits_3, atol=1e-3, rtol=1e-3)
+        logprobs_2 = torch.nn.functional.log_softmax(logits_2, dim=-1)
+        logprobs_3 = torch.nn.functional.log_softmax(logits_3, dim=-1)
+        max_logprob_diff = torch.max(torch.abs(logprobs_2 - logprobs_3)).item()
+
+        # 4b. Generated text
+        text_2 = tokenizer.decode(output_2.sequences[0], skip_special_tokens=True)
+        text_3 = tokenizer.decode(output_3.sequences[0], skip_special_tokens=True)
+        rouge_score = self._calculate_rouge_l([text_2], [text_3])[0]
+        assert rouge_score > 0.99, f"Generated texts do not match (ROUGE-L: {rouge_score})"
+
+        # 5. Performance check
+        with torch.no_grad():
+            time_2 = self._benchmark_generation(model_2, inputs)
+            time_3 = self._benchmark_generation(model_3, inputs)
+
+        print(f"\n--- Flash Attention {2, 3} Parity Test on {model_id} ---")
+        print(f"Prompt: '{prompt}'")
+        print(f"Generated text with Flash Attention 2: {text_2}")
+        print(f"Generated text with Flash Attention 3: {text_3}")
+        print(f"ROUGE-L: {rouge_score}")
+        print(f"Max absolute difference in logprobs: {max_logprob_diff:.5e}")
+        print(f"Flash Attention 2 latency: {time_2:.2f} ms")
+        print(f"Flash Attention 3 latency: {time_3:.2f} ms")
+        print(f"Speed-up: {time_2 / time_3:.2f}x")
+        print("---")

View File

@@ -34,6 +34,7 @@ from transformers.testing_utils import (
     is_flaky,
     require_accelerate,
     require_flash_attn,
+    require_flash_attn_3,
     require_optimum_quanto,
     require_read_token,
     require_torch,
@@ -2292,6 +2293,7 @@ class GenerationTesterMixin:
         support_flag = {
             "sdpa": "_supports_sdpa",
             "flash_attention_2": "_supports_flash_attn_2",
+            "flash_attention_3": "_supports_flash_attn_3",
         }

         for model_class in self.all_generative_model_classes:
@@ -2369,6 +2371,14 @@
         """Tests that generate has equivalent outputs with FA2 and eager attention implementations."""
         self._test_attention_implementation("flash_attention_2")

+    @pytest.mark.flash_attn_3_test
+    @require_flash_attn_3
+    @require_torch_gpu
+    @slow
+    def test_eager_matches_fa3_generate(self):
+        """Tests that generate has equivalent outputs with FA3 and eager attention implementations."""
+        self._test_attention_implementation("flash_attention_3")
+
     def _check_generate_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1):
         input_batch_size = int(output.sequences.shape[0] / num_return_sequences)
         internal_batch_size = (

View File

@@ -84,6 +84,7 @@ from transformers.testing_utils import (
     require_bitsandbytes,
     require_deepspeed,
     require_flash_attn,
+    require_flash_attn_3,
     require_non_hpu,
     require_safetensors,
     require_torch,
@@ -3129,18 +3130,19 @@ class ModelTesterMixin:
                 f"{model_class} is too big for the common tests ({num_params})! It should have 1M max."
             )

-    @require_flash_attn
-    @require_torch_gpu
-    @mark.flash_attn_test
-    @slow
-    @is_flaky()
-    def test_flash_attn_2_inference_equivalence(self):
+    def flash_attn_inference_equivalence(self, attn_implementation: str, padding_side: str):
+        r"""
+        Tests the equivalence between the eager and flash attention implementations.
+        This test is only for inference and runs with `torch_dtype=torch.bfloat16`.
+        """
         if not self.has_attentions:
             self.skipTest(reason="Model architecture does not support attentions")

         for model_class in self.all_model_classes:
-            if not model_class._supports_flash_attn_2:
-                self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
+            if (attn_implementation == "flash_attention_2" and not model_class._supports_flash_attn_2) or (
+                attn_implementation == "flash_attention_3" and not model_class._supports_flash_attn_3
+            ):
+                self.skipTest(f"{model_class.__name__} does not support {attn_implementation}")

             config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
             model = model_class(config)
@@ -3148,7 +3150,7 @@ class ModelTesterMixin:
             with tempfile.TemporaryDirectory() as tmpdirname:
                 model.save_pretrained(tmpdirname)
                 model_fa = model_class.from_pretrained(
-                    tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
+                    tmpdirname, torch_dtype=torch.bfloat16, attn_implementation=attn_implementation
                 )
                 model_fa.to(torch_device)
@@ -3163,9 +3165,12 @@ class ModelTesterMixin:
                 if dummy_attention_mask is not None:
                     dummy_attention_mask = dummy_attention_mask[:1]
-                    dummy_attention_mask[:, 1:] = 1
-                    dummy_attention_mask[:, :1] = 0
+                    if padding_side == "left":
+                        dummy_attention_mask[:, 1:] = 1
+                        dummy_attention_mask[:, :1] = 0
+                    else:
+                        dummy_attention_mask[:, :-1] = 1
+                        dummy_attention_mask[:, -1:] = 0

                 if model.config.is_encoder_decoder:
                     decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:1]
@@ -3220,11 +3225,22 @@ class ModelTesterMixin:
                     else outputs_fa.decoder_hidden_states[-1]
                 )

-                assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2)
+                if padding_side == "left":
+                    assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2)

-                # check with inference + dropout
-                model.train()
-                _ = model_fa(dummy_input, **other_inputs)
+                    # check with inference + dropout
+                    model.train()
+                    _ = model_fa(dummy_input, **other_inputs)
+                else:
+                    assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
+
+    @require_flash_attn
+    @require_torch_gpu
+    @mark.flash_attn_test
+    @slow
+    @is_flaky()
+    def test_flash_attn_2_inference_equivalence(self):
+        self.flash_attn_inference_equivalence(attn_implementation="flash_attention_2", padding_side="left")

     @require_flash_attn
     @require_torch_gpu
@@ -3232,92 +3248,23 @@ class ModelTesterMixin:
     @slow
     @is_flaky()
     def test_flash_attn_2_inference_equivalence_right_padding(self):
-        if not self.has_attentions:
-            self.skipTest(reason="Model architecture does not support attentions")
-
-        for model_class in self.all_model_classes:
-            if not model_class._supports_flash_attn_2:
-                self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
-
-            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-            model = model_class(config)
-
-            with tempfile.TemporaryDirectory() as tmpdirname:
-                model.save_pretrained(tmpdirname)
-                model_fa = model_class.from_pretrained(
-                    tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
-                )
-                model_fa.to(torch_device)
-
-                model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
-                model.to(torch_device)
-
-                dummy_input = inputs_dict[model.main_input_name][:1]
-                if dummy_input.dtype in [torch.float32, torch.float16]:
-                    dummy_input = dummy_input.to(torch.bfloat16)
-
-                dummy_attention_mask = inputs_dict.get("attention_mask", None)
-
-                if dummy_attention_mask is not None:
-                    dummy_attention_mask = dummy_attention_mask[:1]
-                    dummy_attention_mask[:, :-1] = 1
-                    dummy_attention_mask[:, -1:] = 0
-
-                if model.config.is_encoder_decoder:
-                    decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:1]
-
-                    outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
-                    outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
-                else:
-                    outputs = model(dummy_input, output_hidden_states=True)
-                    outputs_fa = model_fa(dummy_input, output_hidden_states=True)
-
-                logits = (
-                    outputs.hidden_states[-1]
-                    if not model.config.is_encoder_decoder
-                    else outputs.decoder_hidden_states[-1]
-                )
-                logits_fa = (
-                    outputs_fa.hidden_states[-1]
-                    if not model.config.is_encoder_decoder
-                    else outputs_fa.decoder_hidden_states[-1]
-                )
-
-                assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
-
-                if model.config.is_encoder_decoder:
-                    other_inputs = {
-                        "decoder_input_ids": decoder_input_ids,
-                        "decoder_attention_mask": dummy_attention_mask,
-                        "output_hidden_states": True,
-                    }
-                    if dummy_attention_mask is not None:
-                        other_inputs["attention_mask"] = dummy_attention_mask
-
-                    outputs = model(dummy_input, **other_inputs)
-                    outputs_fa = model_fa(dummy_input, **other_inputs)
-                else:
-                    other_inputs = {
-                        "output_hidden_states": True,
-                    }
-                    if dummy_attention_mask is not None:
-                        other_inputs["attention_mask"] = dummy_attention_mask
-
-                    outputs = model(dummy_input, **other_inputs)
-                    outputs_fa = model_fa(dummy_input, **other_inputs)
-
-                logits = (
-                    outputs.hidden_states[-1]
-                    if not model.config.is_encoder_decoder
-                    else outputs.decoder_hidden_states[-1]
-                )
-                logits_fa = (
-                    outputs_fa.hidden_states[-1]
-                    if not model.config.is_encoder_decoder
-                    else outputs_fa.decoder_hidden_states[-1]
-                )
-
-                assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
+        self.flash_attn_inference_equivalence(attn_implementation="flash_attention_2", padding_side="right")
+
+    @require_flash_attn_3
+    @require_torch_gpu
+    @mark.flash_attn_3_test
+    @slow
+    @is_flaky()
+    def test_flash_attn_3_inference_equivalence(self):
+        self.flash_attn_inference_equivalence(attn_implementation="flash_attention_3", padding_side="left")
+
+    @require_flash_attn_3
+    @require_torch_gpu
+    @mark.flash_attn_3_test
+    @slow
+    @is_flaky()
+    def test_flash_attn_3_inference_equivalence_right_padding(self):
+        self.flash_attn_inference_equivalence(attn_implementation="flash_attention_3", padding_side="right")
 
     def test_attn_implementation_composite_models(self):
         """
@@ -3959,24 +3906,21 @@ class ModelTesterMixin:
                 torch.allclose(res_eager[attention_mask == 1], res_sdpa[attention_mask == 1], rtol=1e-4, atol=1e-4)
             )
 
-    @require_flash_attn
-    @require_torch_gpu
-    @mark.flash_attn_test
-    def test_flash_attn_2_can_dispatch_composite_models(self):
+    def flash_attn_can_dispatch_composite_models(self, attn_implementation: str):
         """
-        Tests if composite models can dispatch on FA2 if the sub-models support FA2.
+        Tests if composite models can dispatch on flash attention if the sub-models support it.
         The tests is needed as we handle differently composite models and we cannot check them
-        with above tests. If any of the sub-models does not support FA2, we'll raise an error when dispatching
+        with above tests. If any of the sub-models does not support flash attention, we'll raise an error when dispatching
         that particular sub-model. Otherwise we dispatch safely in all sub-models, where "sub-models" are specific
         backbone models (LM/vision/audio/etc)
         """
         if not self.has_attentions:
             self.skipTest(reason="Model architecture does not support attentions")
 
-        if not is_torch_fp16_available_on_device(torch_device):
-            self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)")
+        if not is_torch_bf16_available_on_device(torch_device):
+            self.skipTest(f"bfloat16 not supported on {torch_device} (on the specific device currently used)")
 
-        torch_dtype = torch.float16
+        torch_dtype = torch.bfloat16
         for model_class in self.all_model_classes:
             config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
             model = model_class(config)
@@ -3987,44 +3931,64 @@ class ModelTesterMixin:
                 model.save_pretrained(tmpdirname)
                 model = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype)
 
-                sub_models_supporting_fa2 = [
-                    module._supports_flash_attn_2
+                sub_models_supporting_fa = [
+                    (
+                        module._supports_flash_attn_3
+                        if attn_implementation == "flash_attention_3"
+                        else module._supports_flash_attn_2
+                    )
                     for name, module in model.named_modules()
                     if isinstance(module, PreTrainedModel) and name != ""
                 ]
-                supports_fa2_all_modules = (
-                    all(sub_models_supporting_fa2)
-                    if len(sub_models_supporting_fa2) > 0
-                    else model._supports_flash_attn_2
+                supports_fa_all_modules = (
+                    all(sub_models_supporting_fa)
+                    if len(sub_models_supporting_fa) > 0
+                    else (
+                        model._supports_flash_attn_3
+                        if attn_implementation == "flash_attention_3"
+                        else model._supports_flash_attn_2
+                    )
                 )
-                if not supports_fa2_all_modules:
+                if not supports_fa_all_modules:
                     with self.assertRaises(ValueError):
-                        model_fa2 = model_class.from_pretrained(
+                        model_fa = model_class.from_pretrained(
                             tmpdirname,
                             torch_dtype=torch_dtype,
-                            attn_implementation="flash_attention_2",
+                            attn_implementation=attn_implementation,
                         )
                 else:
-                    model_fa2 = model_class.from_pretrained(
-                        tmpdirname, torch_dtype=torch_dtype, attn_implementation="flash_attention_2"
+                    model_fa = model_class.from_pretrained(
+                        tmpdirname, torch_dtype=torch_dtype, attn_implementation=attn_implementation
                     )
 
-                    for key in model_fa2.config:
-                        if isinstance(getattr(model_fa2.config, key), PretrainedConfig):
-                            sub_config = getattr(model_fa2.config, key)
-                            self.assertTrue(sub_config._attn_implementation == "flash_attention_2")
+                    for key in model_fa.config:
+                        if isinstance(getattr(model_fa.config, key), PretrainedConfig):
+                            sub_config = getattr(model_fa.config, key)
+                            self.assertTrue(sub_config._attn_implementation == attn_implementation)
 
-                    has_fa2 = False
-                    for name, submodule in model_fa2.named_modules():
+                    has_fa = False
+                    for name, submodule in model_fa.named_modules():
                         class_name = submodule.__class__.__name__
                         if (
                             "Attention" in class_name
                             and getattr(submodule, "config", None)
-                            and submodule.config._attn_implementation == "flash_attention_2"
+                            and submodule.config._attn_implementation == attn_implementation
                         ):
-                            has_fa2 = True
+                            has_fa = True
                             break
-                    if not has_fa2:
-                        raise ValueError("The FA2 model should have FA2 layers")
+                    if not has_fa:
+                        raise ValueError(f"The {attn_implementation} model should have {attn_implementation} layers")
+
+    @require_flash_attn
+    @require_torch_gpu
+    @mark.flash_attn_test
+    def test_flash_attn_2_can_dispatch_composite_models(self):
+        self.flash_attn_can_dispatch_composite_models(attn_implementation="flash_attention_2")
+
+    @require_flash_attn_3
+    @require_torch_gpu
+    @mark.flash_attn_3_test
+    def test_flash_attn_3_can_dispatch_composite_models(self):
+        self.flash_attn_can_dispatch_composite_models(attn_implementation="flash_attention_3")
 
     @require_flash_attn
     @require_torch_gpu
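
The dispatch test above can be summarised as: once a composite model is loaded with a flash attention implementation, every sub-config must report that implementation. A minimal sketch, with a placeholder model name:

    import torch
    from transformers import AutoModel, PretrainedConfig

    model = AutoModel.from_pretrained(
        "path/to/composite-model",  # placeholder
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_3",
    )
    for key in model.config:
        sub_config = getattr(model.config, key)
        if isinstance(sub_config, PretrainedConfig):
            # Every sub-model (LM/vision/audio/...) should have been switched over.
            assert sub_config._attn_implementation == "flash_attention_3"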
@@ -4121,27 +4085,29 @@ class ModelTesterMixin:
                 assert not loss.isnan().any()
 
-    @require_flash_attn
-    @require_torch_gpu
-    @mark.flash_attn_test
-    @slow
-    def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
+    def flash_attention_padding_matches_padding_free_with_position_ids(
+        self, attn_implementation: str, fa_kwargs: bool = False
+    ):
         if not self.has_attentions:
             self.skipTest(reason="Model architecture does not support attentions")
 
         max_new_tokens = 30
 
         for model_class in self.all_generative_model_classes:
-            if not model_class._supports_flash_attn_2:
-                self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
+            if not (
+                model_class._supports_flash_attn_2
+                if attn_implementation == "flash_attention_2"
+                else model_class._supports_flash_attn_3
+            ):
+                self.skipTest(f"{model_class.__name__} does not support {attn_implementation}")
 
             config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
             if 0 not in inputs_dict.get("attention_mask", []) or "attention_mask" not in inputs_dict:
                 self.skipTest("Model dummy inputs should contain padding in their attention mask")
 
             dummy_input = inputs_dict[model_class.main_input_name]
-            if dummy_input.dtype in [torch.float32, torch.bfloat16]:
-                dummy_input = dummy_input.to(torch.float16)
+            if dummy_input.dtype in [torch.float32, torch.float16]:
+                dummy_input = dummy_input.to(torch.bfloat16)
 
             # make sure that all models have enough positions for generation
             if hasattr(config, "max_position_embeddings"):
@@ -4151,7 +4117,7 @@ class ModelTesterMixin:
             if "position_ids" not in inspect.signature(model.forward).parameters:
                 self.skipTest("Model does not support position_ids")
 
-            if "position_ids" not in inspect.signature(model.forward).parameters:
+            if (not fa_kwargs) and "position_ids" not in inspect.signature(model.forward).parameters:
                 continue  # this model doesn't accept position ids as input
 
             with tempfile.TemporaryDirectory() as tmpdirname:
@@ -4166,13 +4132,27 @@ class ModelTesterMixin:
                 model = (
                     model_class.from_pretrained(
                         tmpdirname,
-                        torch_dtype=torch.float16,
-                        attn_implementation="flash_attention_2",
+                        torch_dtype=torch.bfloat16,
+                        attn_implementation=attn_implementation,
                     )
                     .to(torch_device)
                     .eval()
                 )
 
-                # flatten
-                padfree_inputs_dict = {
-                    k: v[dummy_attention_mask.bool()].unsqueeze(0)
+                if fa_kwargs:
+                    # flatten
+                    features = [
+                        {"input_ids": i[a.bool()].tolist()}
+                        for i, a in zip(inputs_dict["input_ids"], inputs_dict["attention_mask"])
+                    ]
+
+                    # add position_ids + fa_kwargs
+                    data_collator = DataCollatorWithFlattening(return_tensors="pt", return_flash_attn_kwargs=True)
+                    batch = data_collator(features)
+                    padfree_inputs_dict = {
+                        k: t.to(torch_device) if torch.is_tensor(t) else t for k, t in batch.items()
+                    }
+                else:
+                    # flatten
+                    padfree_inputs_dict = {
+                        k: v[dummy_attention_mask.bool()].unsqueeze(0)
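
For reference, a minimal sketch of what the `fa_kwargs` branch feeds the model: `DataCollatorWithFlattening` concatenates the unpadded sequences into a single row, restarts `position_ids` at each document boundary and, with `return_flash_attn_kwargs=True`, also returns the cumulative sequence-length kwargs so no attention mask is needed (token values below are arbitrary):

    from transformers import DataCollatorWithFlattening

    features = [
        {"input_ids": [101, 7, 8, 9]},
        {"input_ids": [101, 5, 6]},
    ]
    collator = DataCollatorWithFlattening(return_tensors="pt", return_flash_attn_kwargs=True)
    batch = collator(features)

    print(batch["input_ids"].shape)  # torch.Size([1, 7]) -- one flattened row
    print(batch["position_ids"][0])  # tensor([0, 1, 2, 3, 0, 1, 2])
    print(sorted(k for k in batch if k.startswith("cu_seq_lens")))  # ['cu_seq_lens_k', 'cu_seq_lens_q']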
@@ -4195,119 +4175,96 @@ class ModelTesterMixin:
                 torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), rtol=0, atol=0)
                 # acceptable numerical instability
-                tol = torch.finfo(torch.float16).eps
+                tol = torch.finfo(torch.bfloat16).eps
                 torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
 
+    @require_flash_attn
+    @require_torch_gpu
+    @mark.flash_attn_test
+    @slow
+    def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
+        self.flash_attention_padding_matches_padding_free_with_position_ids(attn_implementation="flash_attention_2")
+
     @require_flash_attn
     @require_torch_gpu
     @mark.flash_attn_test
     @slow
     def test_flash_attention_2_padding_matches_padding_free_with_position_ids_and_fa_kwargs(self):
-        if not self.has_attentions:
-            self.skipTest(reason="Model architecture does not support attentions")
-
-        max_new_tokens = 30
-
-        for model_class in self.all_generative_model_classes:
-            if not model_class._supports_flash_attn_2:
-                self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
-
-            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-            if 0 not in inputs_dict.get("attention_mask", []) or "attention_mask" not in inputs_dict:
-                self.skipTest("Model dummy inputs should contain padding in their attention mask")
-
-            dummy_input = inputs_dict[model_class.main_input_name]
-            if dummy_input.dtype in [torch.float32, torch.bfloat16]:
-                dummy_input = dummy_input.to(torch.float16)
-
-            # make sure that all models have enough positions for generation
-            if hasattr(config, "max_position_embeddings"):
-                config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
-
-            model = model_class(config)
-            if "position_ids" not in inspect.signature(model.forward).parameters:
-                self.skipTest("Model does not support position_ids")
-
-            with tempfile.TemporaryDirectory() as tmpdirname:
-                model.save_pretrained(tmpdirname)
-
-                # ensure left padding, to adapt for some models
-                if 0 in inputs_dict["attention_mask"][:, -1]:
-                    inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)
-                dummy_attention_mask = inputs_dict["attention_mask"]
-                inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.get_text_config().pad_token_id
-
-                model = (
-                    model_class.from_pretrained(
-                        tmpdirname,
-                        torch_dtype=torch.float16,
-                        attn_implementation="flash_attention_2",
-                    )
-                    .to(torch_device)
-                    .eval()
-                )
-
-                # flatten
-                features = [
-                    {"input_ids": i[a.bool()].tolist()}
-                    for i, a in zip(inputs_dict["input_ids"], inputs_dict["attention_mask"])
-                ]
-
-                # add position_ids + fa_kwargs
-                data_collator = DataCollatorWithFlattening(return_tensors="pt", return_flash_attn_kwargs=True)
-                batch = data_collator(features)
-                batch_accelerator = {k: t.to(torch_device) if torch.is_tensor(t) else t for k, t in batch.items()}
-
-                res_padded = model(**inputs_dict)
-                res_padfree = model(**batch_accelerator)
-
-                logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()]
-                logits_padfree = res_padfree.logits[0]
-
-                torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), rtol=0, atol=0)
-                # acceptable numerical instability
-                tol = torch.finfo(torch.float16).eps
-                torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
+        self.flash_attention_padding_matches_padding_free_with_position_ids(
+            attn_implementation="flash_attention_2", fa_kwargs=True
+        )
 
-    @require_flash_attn
+    @require_flash_attn_3
     @require_torch_gpu
-    @mark.flash_attn_test
+    @mark.flash_attn_3_test
     @slow
-    def test_flash_attn_2_from_config(self):
+    def test_flash_attention_3_padding_matches_padding_free_with_position_ids(self):
+        self.flash_attention_padding_matches_padding_free_with_position_ids(attn_implementation="flash_attention_3")
+
+    @require_flash_attn_3
+    @require_torch_gpu
+    @mark.flash_attn_3_test
+    @slow
+    def test_flash_attention_3_padding_matches_padding_free_with_position_ids_and_fa_kwargs(self):
+        self.flash_attention_padding_matches_padding_free_with_position_ids(
+            attn_implementation="flash_attention_3", fa_kwargs=True
+        )
+
+    def flash_attn_from_config(self, attn_implementation: str):
+        r"""
+        Tests if the model can be loaded with `attn_implementation` from the config and if the
+        weights are not randomly initialized.
+        """
         if not self.has_attentions:
             self.skipTest(reason="Model architecture does not support attentions")
 
         for model_class in self.all_generative_model_classes:
-            if not model_class._supports_flash_attn_2:
-                self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
+            if (attn_implementation == "flash_attention_2" and not model_class._supports_flash_attn_2) or (
+                attn_implementation == "flash_attention_3" and not model_class._supports_flash_attn_3
+            ):
+                self.skipTest(f"{model_class.__name__} does not support {attn_implementation}")
 
             config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
             # TODO: to change it in the future with other relevant auto classes
-            fa2_model = model_class._from_config(
-                config, attn_implementation="flash_attention_2", torch_dtype=torch.float16
+            fa_model = model_class._from_config(
+                config, attn_implementation=attn_implementation, torch_dtype=torch.bfloat16
             ).to(torch_device)
 
-            dummy_input = inputs_dict[fa2_model.main_input_name]
-            if dummy_input.dtype in [torch.float32, torch.bfloat16]:
-                dummy_input = dummy_input.to(torch.float16)
+            dummy_input = inputs_dict[fa_model.main_input_name]
+            if dummy_input.dtype in [torch.float32, torch.float16]:
+                dummy_input = dummy_input.to(torch.bfloat16)
             dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
-            if fa2_model.config.is_encoder_decoder:
+            if fa_model.config.is_encoder_decoder:
                 dummy_decoder_input_ids = inputs_dict["decoder_input_ids"]
                 dummy_decoder_attention_mask = inputs_dict["decoder_attention_mask"]
-                _ = fa2_model(
+                _ = fa_model(
                     dummy_input,
                     attention_mask=dummy_attention_mask,
                     decoder_input_ids=dummy_decoder_input_ids,
                     decoder_attention_mask=dummy_decoder_attention_mask,
                 )
             else:
-                _ = fa2_model(dummy_input, attention_mask=dummy_attention_mask)
+                _ = fa_model(dummy_input, attention_mask=dummy_attention_mask)
 
             with tempfile.TemporaryDirectory() as tmpdirname:
-                fa2_model.save_pretrained(tmpdirname)
+                fa_model.save_pretrained(tmpdirname)
                 model_from_pretrained = model_class.from_pretrained(tmpdirname)
-                self.assertTrue(model_from_pretrained.config._attn_implementation != "flash_attention_2")
+                self.assertTrue(model_from_pretrained.config._attn_implementation != attn_implementation)
+
+    @require_flash_attn
+    @require_torch_gpu
+    @mark.flash_attn_test
+    @slow
+    def test_flash_attn_2_from_config(self):
+        self.flash_attn_from_config(attn_implementation="flash_attention_2")
+
+    @require_flash_attn
+    @require_torch_gpu
+    @mark.flash_attn_3_test
+    @slow
+    def test_flash_attn_3_from_config(self):
+        self.flash_attn_from_config(attn_implementation="flash_attention_3")
 
     def _get_custom_4d_mask_test_data(self):
         # Sequence in which all but the last token is the same
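
The `from_config` helper above asserts a property worth spelling out: the implementation requested at load time is not persisted into the saved config. A minimal sketch with a placeholder checkpoint:

    import tempfile

    import torch
    from transformers import AutoConfig, AutoModelForCausalLM

    config = AutoConfig.from_pretrained("path/to/tiny-model")  # placeholder
    model = AutoModelForCausalLM.from_config(
        config, attn_implementation="flash_attention_3", torch_dtype=torch.bfloat16
    )
    assert model.config._attn_implementation == "flash_attention_3"

    with tempfile.TemporaryDirectory() as tmpdirname:
        model.save_pretrained(tmpdirname)
        reloaded = AutoModelForCausalLM.from_pretrained(tmpdirname)
        # The saved checkpoint does not pin the implementation used at load time.
        assert reloaded.config._attn_implementation != "flash_attention_3"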

View File

@@ -77,6 +77,7 @@ from transformers.utils import (
 )
 from transformers.utils.import_utils import (
     is_flash_attn_2_available,
+    is_flash_attn_3_available,
     is_flax_available,
     is_tf_available,
     is_torch_npu_available,
@@ -676,6 +677,9 @@ class ModelUtilsTest(TestCasePlus):
         if is_flash_attn_available():
             attn_implementation_available.append("flash_attention_2")
 
+        if is_flash_attn_3_available():
+            attn_implementation_available.append("flash_attention_3")
+
         for requested_attn_implementation in attn_implementation_available:
             model = AutoModelForCausalLM.from_pretrained(
                 TINY_MISTRAL, attn_implementation=requested_attn_implementation
@@ -700,6 +704,9 @@ class ModelUtilsTest(TestCasePlus):
         if is_flash_attn_available():
             attn_implementation_available.append("flash_attention_2")
 
+        if is_flash_attn_3_available():
+            attn_implementation_available.append("flash_attention_3")
+
         for requested_attn_implementation in attn_implementation_available:
             config = AutoConfig.from_pretrained(TINY_MISTRAL, attn_implementation=requested_attn_implementation)
             # Ensure the config was set correctly
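
Outside the test suite, the same availability checks can be used to pick a backend at runtime; a small sketch mirroring the selection above (model path is a placeholder):

    import torch
    from transformers import AutoModelForCausalLM
    from transformers.utils.import_utils import is_flash_attn_2_available, is_flash_attn_3_available

    if is_flash_attn_3_available():
        attn_implementation = "flash_attention_3"
    elif is_flash_attn_2_available():
        attn_implementation = "flash_attention_2"
    else:
        attn_implementation = "sdpa"  # PyTorch scaled_dot_product_attention fallback

    model = AutoModelForCausalLM.from_pretrained(
        "path/to/model",  # placeholder
        torch_dtype=torch.bfloat16,
        attn_implementation=attn_implementation,
    )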