Support for Flash Attention 3 (#38972)

* Support `flash_attn_3`
Implements fwd and tests for Flash Attention 3 https://github.com/Dao-AILab/flash-attention/commits/main/hopper

- Includes checks for dropout>0 and ALiBi in `modeling_utils.PreTrainedModel._check_and_enable_flash_attn_3` (Dropout will likely be supported soon, so this will need to be updated and `modeling_flash_attention_utils._flash_attention_forward` at the `if _IS_FLASH_ATTN_3_AVAILABLE: ...`

An example Llama implementation is included in `modeling_llama.py` but other models would still need to be updated

Based on https://github.com/huggingface/transformers/pull/36190 which has model implementations and examples which could be merged

* Add tests for Flash Attention 2 and 3 parity

* ci fix

* FA2 compatibiity
- `_prepare_flash_attention_from_position_ids` ->`prepare_fa2_from_position_ids`
- Remove bettertransformer check in Flash Attention 3
- Merge tests
- Add licensing

* ci fix

* Test naming consistency

* ci fix

* Deprecation warning for `prepare_fa2_from_position_ids`

* ci fix
This commit is contained in:
EduardDurech 2025-06-25 14:39:27 +02:00 committed by GitHub
parent de98fb25a3
commit a2eb75c891
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
42 changed files with 698 additions and 262 deletions

View File

@ -52,6 +52,7 @@ line-ending = "auto"
addopts = "--doctest-glob='**/*.md'"
doctest_optionflags="NUMBER NORMALIZE_WHITESPACE ELLIPSIS"
markers = [
"flash_attn_3_test: marks tests related to flash attention 3 (deselect with '-m \"not flash_attn_3_test\"')",
"flash_attn_test: marks tests related to flash attention (deselect with '-m \"not flash_attn_test\"')",
"bitsandbytes: select (or deselect with `not`) bitsandbytes integration tests",
"generate: marks tests that use the GenerationTesterMixin"

View File

@ -75,6 +75,7 @@ def flash_attention_forward(
softcap=softcap,
use_top_left_mask=_use_top_left_mask,
target_dtype=target_dtype,
attn_implementation=module.config._attn_implementation,
**kwargs,
)

View File

@ -14,6 +14,7 @@
import inspect
import os
import warnings
from typing import Optional, TypedDict
import torch
@ -21,6 +22,7 @@ import torch.nn.functional as F
from .utils import (
is_flash_attn_2_available,
is_flash_attn_3_available,
is_flash_attn_greater_or_equal,
is_flash_attn_greater_or_equal_2_10,
is_torch_npu_available,
@ -32,18 +34,123 @@ logger = logging.get_logger(__name__)
flash_attn_func = None
if is_flash_attn_2_available():
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.layers.rotary import apply_rotary_emb # noqa
def _index_first_axis(tensor, indices):
"""
A local implementation of the PyTorch indexing operation `tensor[indices]` on the first axis,
after flattening the first two dimensions of the tensor. This is functionally equivalent to
FA2's `index_first_axis` and replaces the need to import it.
"""
# The input tensor is expected to be of shape (batch, seq_len, ...). We flatten the first
# two dimensions to get (total_tokens, ...) before indexing.
reshaped_tensor = tensor.reshape(-1, *tensor.shape[2:])
return reshaped_tensor[indices]
def _fa3_unpad_input(hidden_states, attention_mask, unused_mask=None):
"""
FA3-compatible unpad_input function.
Arguments:
hidden_states: (batch, seqlen, ...)
attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.
Return:
hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
indices: (total_nnz), the indices of masked tokens from the flattened input sequence.
cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
max_seqlen_in_batch: int
seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask.
"""
all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask
seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return (
_index_first_axis(hidden_states, indices),
indices,
cu_seqlens,
max_seqlen_in_batch,
used_seqlens_in_batch,
)
def _fa3_pad_input(hidden_states, indices, batch, seqlen):
"""
FA3-compatible pad_input function.
Arguments:
hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
batch: int, batch size for the padded sequence.
seqlen: int, maximum sequence length for the padded sequence.
Return:
hidden_states: (batch, seqlen, ...)
"""
dim = hidden_states.shape[1:]
output = torch.zeros((batch * seqlen), *dim, device=hidden_states.device, dtype=hidden_states.dtype)
output[indices] = hidden_states
return output.view(batch, seqlen, *dim)
FA_VERSION = None
if is_flash_attn_2_available():
from flash_attn import flash_attn_func as flash_attn_2_func
from flash_attn import flash_attn_varlen_func as flash_attn_2_varlen_func
from flash_attn.bert_padding import pad_input as pad_input_fa2
from flash_attn.bert_padding import unpad_input as unpad_input_fa2
from flash_attn.layers.rotary import apply_rotary_emb
HAS_FA2 = True
FA_VERSION = 2
else:
flash_attn_2_func = None
flash_attn_2_varlen_func = None
pad_input_fa2 = None
unpad_input_fa2 = None
apply_rotary_emb = None
HAS_FA2 = False
if is_flash_attn_3_available():
from flash_attn_interface import flash_attn_func as flash_attn_3_func
from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
pad_input_fa3 = _fa3_pad_input
unpad_input_fa3 = _fa3_unpad_input
HAS_FA3 = True
FA_VERSION = 3
else:
flash_attn_3_func = None
flash_attn_3_varlen_func = None
pad_input_fa3 = None
unpad_input_fa3 = None
HAS_FA3 = False
# Current Flash Attention implementations
if FA_VERSION:
flash_attn_func = globals()[f"flash_attn_{FA_VERSION}_func"]
flash_attn_varlen_func = globals()[f"flash_attn_{FA_VERSION}_varlen_func"]
unpad_input = globals()[f"unpad_input_fa{FA_VERSION}"]
pad_input = globals()[f"pad_input_fa{FA_VERSION}"]
# patch functions in package `flash-attn` when using flash-attention on Ascend NPU.
if is_torch_npu_available():
from .integrations.npu_flash_attention import index_first_axis, pad_input, unpad_input
from .integrations.npu_flash_attention import npu_apply_rotary_emb as apply_rotary_emb # noqa
from .integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_func
from .integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_varlen_func
from .integrations.npu_flash_attention import (
npu_apply_rotary_emb as apply_rotary_emb, # noqa: F401
)
from .integrations.npu_flash_attention import (
npu_flash_attn_func as flash_attn_func,
)
from .integrations.npu_flash_attention import (
npu_flash_attn_varlen_func as flash_attn_varlen_func,
)
from .integrations.npu_flash_attention import (
pad_input,
unpad_input,
)
_flash_supports_window_size = False
@ -56,6 +163,9 @@ if flash_attn_func:
def is_flash_attn_available():
"""Determine whether flash-attention can be used or not."""
if is_flash_attn_3_available():
return True
# if package `flash-attn` is available, flash-attention can be used natively.
if is_flash_attn_2_available():
return True
@ -70,6 +180,9 @@ def is_flash_attn_available():
def flash_attn_supports_top_left_mask():
"""Determine whether flash-attention uses top-left or down-right mask"""
if is_flash_attn_3_available():
return False
if is_flash_attn_2_available():
# top-left mask is used in package `flash-attn` with version lower than 2.1.0
return not is_flash_attn_greater_or_equal_2_10()
@ -116,6 +229,7 @@ def _upad_input(
value_layer: torch.Tensor,
attention_mask: torch.Tensor,
query_length: int,
unpad_input_func,
):
"""
Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches.
@ -134,6 +248,8 @@ def _upad_input(
Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
query_length (`int`):
Target length.
unpad_input_func:
The function to use for unpadding the input tensors.
Return:
query_layer (`torch.Tensor`):
@ -158,12 +274,10 @@ def _upad_input(
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k)
value_layer = index_first_axis(
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
key_layer = _index_first_axis(key_layer, indices_k)
value_layer = _index_first_axis(value_layer, indices_k)
if query_length == kv_seq_len:
query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, -1, head_dim), indices_k)
query_layer = _index_first_axis(query_layer, indices_k)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
@ -177,7 +291,7 @@ def _upad_input(
else:
# The -q_len: slice assumes left padding.
attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q, *_ = unpad_input(query_layer, attention_mask)
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q, *_ = unpad_input_func(query_layer, attention_mask)
return (
query_layer,
@ -189,7 +303,7 @@ def _upad_input(
)
def prepare_fa2_from_position_ids(query, key, value, position_ids):
def _prepare_flash_attention_from_position_ids(query, key, value, position_ids):
"""
This function returns necessary arguments to call `flash_attn_varlen_func`.
All three query, key, value states will be flattened.
@ -239,6 +353,14 @@ def prepare_fa2_from_position_ids(query, key, value, position_ids):
return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length))
def prepare_fa2_from_position_ids(*args, **kwargs):
warnings.warn(
"The function `prepare_fa2_from_position_ids` in `transformers.modeling_flash_attention_utils` is deprecated and will be removed in a future version. Please use `_prepare_flash_attention_from_position_ids` instead.",
FutureWarning,
)
return _prepare_flash_attention_from_position_ids(*args, **kwargs)
def fa_peft_integration_check(
query: torch.Tensor,
key: torch.Tensor,
@ -303,6 +425,7 @@ def _flash_attention_forward(
max_length_q: Optional[int] = None,
max_length_k: Optional[int] = None,
target_dtype: Optional[torch.dtype] = None,
attn_implementation: Optional[str] = None,
**kwargs,
):
"""
@ -329,7 +452,28 @@ def _flash_attention_forward(
Softcap for the attention logits, used e.g. in gemma2.
deterministic (`bool`, *optional*):
Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled.
attn_implementation (`str`, *optional*):
The attention implementation to use. If None, will default to the one based on the environment.
"""
if attn_implementation is None:
_flash_attn_varlen_func = flash_attn_varlen_func
_flash_attn_func = flash_attn_func
_pad_input = pad_input
_unpad_input = unpad_input
_is_fa3 = HAS_FA3
elif attn_implementation == "flash_attention_3":
_flash_attn_varlen_func = flash_attn_3_varlen_func
_flash_attn_func = flash_attn_3_func
_pad_input = pad_input_fa3
_unpad_input = unpad_input_fa3
_is_fa3 = True
elif attn_implementation == "flash_attention_2":
_flash_attn_varlen_func = flash_attn_2_varlen_func
_flash_attn_func = flash_attn_2_func
_pad_input = pad_input_fa2
_unpad_input = unpad_input_fa2
_is_fa3 = False
if not use_top_left_mask:
causal = is_causal
else:
@ -342,6 +486,12 @@ def _flash_attention_forward(
)
flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}
if _is_fa3:
if dropout > 0.0:
logger.warning_once("Flash Attention 3 does not support dropout. Setting dropout to 0.0.")
else:
flash_kwargs["dropout_p"] = dropout
if flash_241:
if deterministic is None:
global deterministic_g
@ -362,12 +512,12 @@ def _flash_attention_forward(
if attention_mask is not None:
batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = _upad_input(
query_states, key_states, value_states, attention_mask, query_length
query_states, key_states, value_states, attention_mask, query_length, _unpad_input
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
attn_output_unpad = flash_attn_varlen_func(
attn_output_unpad = _flash_attn_varlen_func(
query_states,
key_states,
value_states,
@ -375,12 +525,11 @@ def _flash_attention_forward(
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
**flash_kwargs,
)
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
attn_output = _pad_input(attn_output_unpad, indices_q, batch_size, query_length)
# If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing
# then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage.
@ -394,7 +543,7 @@ def _flash_attention_forward(
if cu_seq_lens_q is None or cu_seq_lens_k is None:
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = (
prepare_fa2_from_position_ids(query_states, key_states, value_states, position_ids)
_prepare_flash_attention_from_position_ids(query_states, key_states, value_states, position_ids)
)
cu_seq_lens_q, cu_seq_lens_k = cu_seq_lens
@ -405,7 +554,7 @@ def _flash_attention_forward(
key_states = key_states.reshape(-1, key_states.size(-2), key_states.size(-1))
value_states = value_states.reshape(-1, value_states.size(-2), value_states.size(-1))
attn_output = flash_attn_varlen_func(
attn_output = _flash_attn_varlen_func(
query_states,
key_states,
value_states,
@ -413,7 +562,6 @@ def _flash_attention_forward(
cu_seqlens_k=cu_seq_lens_k,
max_seqlen_q=max_length_q,
max_seqlen_k=max_length_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
**flash_kwargs,
@ -422,10 +570,12 @@ def _flash_attention_forward(
attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))
else:
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal, **flash_kwargs
attn_output = _flash_attn_func(
query_states, key_states, value_states, softmax_scale=softmax_scale, causal=causal, **flash_kwargs
)
if isinstance(attn_output, tuple):
return attn_output[0]
return attn_output

View File

@ -105,6 +105,7 @@ from .utils import (
is_accelerate_available,
is_bitsandbytes_available,
is_flash_attn_2_available,
is_flash_attn_3_available,
is_kernels_available,
is_offline_mode,
is_optimum_available,
@ -1957,6 +1958,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
# Flash Attention 2 support
_supports_flash_attn_2 = False
# Flash Attention 3 support
_supports_flash_attn_3 = False
# SDPA support
_supports_sdpa = False
@ -2247,6 +2251,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
and config._attn_implementation not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys()
):
message = f'Specified `attn_implementation="{config._attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)'
if cls._supports_flash_attn_3:
message += ', `"attn_implementation=flash_attention_3"` (implementation using flash attention 3)'
if cls._supports_flash_attn_2:
message += ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)'
if cls._supports_sdpa:
@ -2282,7 +2288,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
):
sub_config._attn_implementation_internal = curr_attn_implementation
if config._attn_implementation == "flash_attention_2":
if config._attn_implementation == "flash_attention_3":
cls._check_and_enable_flash_attn_3(
config,
torch_dtype=torch_dtype,
device_map=device_map,
hard_check_only=False,
check_device_map=check_device_map,
)
elif config._attn_implementation == "flash_attention_2":
cls._check_and_enable_flash_attn_2(
config,
torch_dtype=torch_dtype,
@ -2498,6 +2512,94 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
config._attn_implementation = "flash_attention_2"
return config
@classmethod
def _check_and_enable_flash_attn_3(
cls,
config,
torch_dtype: Optional[torch.dtype] = None,
device_map: Optional[Union[str, dict[str, int]]] = None,
check_device_map: bool = True,
hard_check_only: bool = False,
) -> PretrainedConfig:
"""
Checks the availability of Flash Attention 3 and compatibility with the current model.
If all checks pass and `hard_check_only` is False, the method will set the config attribute `attn_implementation` to "flash_attention_3" so that the model can initialize the correct attention module.
"""
if not cls._supports_flash_attn_3:
raise ValueError(
f"{cls.__name__} does not support Flash Attention 3.0 yet. Please request to add support where"
f" the model is hosted, on its model hub page: https://huggingface.co/{config._name_or_path}/discussions/new"
" or in the Transformers GitHub repo: https://github.com/huggingface/transformers/issues/new"
)
if not is_flash_attn_3_available():
preface = "FlashAttention3 has been toggled on, but it cannot be used due to the following error:"
if importlib.util.find_spec("flash_attn_3") is None:
raise ImportError(f"{preface} the package flash_attn_3 seems to be not installed.")
if torch.cuda.is_available():
major, _ = torch.cuda.get_device_capability()
if major < 9:
raise ValueError(
f"{preface} Flash Attention 3 requires compute capability >= 9.0, but found {torch.cuda.get_device_capability()} with compute capability {major}.0."
)
else:
raise ImportError(f"{preface} Flash Attention 3 is not available.")
else:
raise ValueError(
f"{preface} Flash Attention 3 is not available on CPU. Please make sure torch can access a CUDA device."
)
if torch_dtype is None:
logger.warning_once(
"You are attempting to use Flash Attention 3 without specifying a torch dtype. This might lead to unexpected behaviour"
)
elif torch_dtype is not None and torch_dtype not in [torch.float16, torch.bfloat16]:
logger.warning_once(
"Flash Attention 3 only supports torch.float16 and torch.bfloat16 dtypes, but"
f" the current dype in {cls.__name__} is {torch_dtype}. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator,"
' or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("meta-llama/Llama-3.2-1B", attn_implementation="flash_attention_3", torch_dtype=torch.float16)`'
)
if getattr(config, "alibi", False) or getattr(config, "use_alibi", False):
raise ValueError("Model is configured to use ALiBi, which is not supported by Flash Attention 3.")
# Check for attention dropout, which is incompatible with FA3
if hasattr(config, "attention_dropout") and config.attention_dropout > 0:
raise ValueError(
f"Model has attention_dropout={config.attention_dropout}, which is not supported by Flash Attention 3."
)
# The check `torch.empty(0).device.type != "cuda"` is needed as the model may be initialized after `torch.set_default_device` has been called,
# or the model may be initialized under the context manager `with torch.device("cuda"):`.
if check_device_map and device_map is None and torch.empty(0).device.type not in ["cuda", "mlu"]:
if torch.cuda.is_available():
logger.warning_once(
"You are attempting to use Flash Attention 3 with a model not initialized on GPU. Make sure to move the model to GPU"
" after initializing it on CPU with `model.to('cuda')`."
)
else:
raise ValueError(
"You are attempting to use Flash Attention 3 with a model not initialized on GPU and with no GPU available. "
"This is not supported yet. Please make sure to have access to a GPU and either initialise the model on a GPU by passing a device_map "
"or initialising the model on CPU and then moving it to GPU."
)
elif (
check_device_map
and device_map is not None
and isinstance(device_map, dict)
and ("cpu" in device_map.values() or "disk" in device_map.values())
):
raise ValueError(
"You are attempting to use Flash Attention 3 with a model dispatched on CPU or disk. This is not supported. Please make sure to "
"initialise the model on a GPU by passing a device_map that contains only GPU devices as keys."
)
if not hard_check_only:
config._attn_implementation = "flash_attention_3"
return config
@classmethod
def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> PretrainedConfig:
"""
@ -4134,7 +4236,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
</Tip>
attn_implementation (`str`, *optional*):
The attention implementation to use in the model (if relevant). Can be any of `"eager"` (manual implementation of the attention), `"sdpa"` (using [`F.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)), or `"flash_attention_2"` (using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation.
The attention implementation to use in the model (if relevant). Can be any of `"eager"` (manual implementation of the attention), `"sdpa"` (using [`F.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)), `"flash_attention_2"` (using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)), or `"flash_attention_3"` (using [Dao-AILab/flash-attention/hopper](https://github.com/Dao-AILab/flash-attention/tree/main/hopper)). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation.
> Parameters for big model inference
@ -5770,6 +5872,7 @@ class AttentionInterface(GeneralInterface):
# Class instance object, so that a call to `register` can be reflected into all other files correctly, even if
# a new instance is created (in order to locally override a given function)
_global_mapping = {
"flash_attention_3": flash_attention_forward,
"flash_attention_2": flash_attention_forward,
"flex_attention": flex_attention_forward,
"paged_attention": paged_attention_forward,

View File

@ -321,6 +321,7 @@ class ArceePreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["ArceeDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True

View File

@ -667,6 +667,7 @@ class AriaPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["AriaDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True

View File

@ -318,6 +318,7 @@ class BitNetPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["BitNetDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True

View File

@ -355,6 +355,7 @@ class CoherePreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["CohereDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True

View File

@ -334,6 +334,7 @@ class Cohere2PreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["Cohere2DecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True

View File

@ -504,6 +504,7 @@ class DeepseekV3PreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["DeepseekV3DecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True

View File

@ -556,6 +556,7 @@ class DiffLlamaPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["DiffLlamaDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = False

View File

@ -424,6 +424,7 @@ class Dots1PreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["Dots1DecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True

View File

@ -318,6 +318,7 @@ class GemmaPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["GemmaDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True

View File

@ -339,6 +339,7 @@ class Gemma2PreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["Gemma2DecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True

View File

@ -422,6 +422,7 @@ class Gemma3PreTrainedModel(PreTrainedModel):
"SiglipMultiheadAttentionPoolingHead",
]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True

View File

@ -335,6 +335,7 @@ class GlmPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["GlmDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True

View File

@ -343,6 +343,7 @@ class Glm4PreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["Glm4DecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True

View File

@ -292,6 +292,7 @@ class GPTNeoXPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["GPTNeoXLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True

View File

@ -305,6 +305,7 @@ class GranitePreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["GraniteDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True

View File

@ -320,6 +320,7 @@ class HeliumPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["HeliumDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True

View File

@ -320,6 +320,7 @@ class LlamaPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["LlamaDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True

View File

@ -590,6 +590,7 @@ class MiniMaxPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["MiniMaxDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True

View File

@ -262,6 +262,7 @@ class MistralPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["MistralDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True

View File

@ -417,6 +417,7 @@ class MixtralPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["MixtralDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True

View File

@ -301,6 +301,7 @@ class OlmoPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["OlmoDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True

View File

@ -305,6 +305,7 @@ class Olmo2PreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["Olmo2DecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True

View File

@ -295,6 +295,7 @@ class PhiPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["PhiDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True

View File

@ -316,6 +316,7 @@ class Phi3PreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["Phi3DecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True

View File

@ -1622,6 +1622,7 @@ class Phi4MultimodalPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["Phi4MultimodalDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True

View File

@ -266,6 +266,7 @@ class Qwen2PreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["Qwen2DecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True

View File

@ -292,6 +292,7 @@ class Qwen3PreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["Qwen3DecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True

View File

@ -424,6 +424,7 @@ class Qwen3MoePreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["Qwen3MoeDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True

View File

@ -299,6 +299,7 @@ class Starcoder2PreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["Starcoder2DecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True

View File

@ -561,6 +561,7 @@ class T5GemmaPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["T5GemmaBlock"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True

View File

@ -86,6 +86,7 @@ from .utils import (
is_faiss_available,
is_fbgemm_gpu_available,
is_flash_attn_2_available,
is_flash_attn_3_available,
is_flax_available,
is_flute_available,
is_fsdp_available,
@ -571,6 +572,15 @@ def require_flash_attn(test_case):
return unittest.skipUnless(is_flash_attn_2_available(), "test requires Flash Attention")(test_case)
def require_flash_attn_3(test_case):
"""
Decorator marking a test that requires Flash Attention 3.
These tests are skipped when Flash Attention 3 isn't installed.
"""
return unittest.skipUnless(is_flash_attn_3_available(), "test requires Flash Attention 3")(test_case)
def require_torch_sdpa(test_case):
"""
Decorator marking a test that requires PyTorch's SDPA.

View File

@ -153,6 +153,7 @@ from .import_utils import (
is_faiss_available,
is_fbgemm_gpu_available,
is_flash_attn_2_available,
is_flash_attn_3_available,
is_flash_attn_greater_or_equal,
is_flash_attn_greater_or_equal_2_10,
is_flax_available,

View File

@ -926,6 +926,9 @@ class ClassAttrs:
_skip_keys_device_placement = r"""
A list of keys to ignore when moving inputs or outputs between devices when using the `accelerate` library.
"""
_supports_flash_attn_3 = r"""
Whether the model's attention implementation supports FlashAttention 3.0.
"""
_supports_flash_attn_2 = r"""
Whether the model's attention implementation supports FlashAttention 2.0.
"""

View File

@ -1120,6 +1120,25 @@ def is_flash_attn_2_available():
return False
@lru_cache()
def is_flash_attn_3_available():
if not is_torch_available():
return False
if not _is_package_available("flash_attn_3"):
return False
import torch
if not torch.cuda.is_available():
return False
# TODO: Check for a minimum version when FA3 is stable
# return version.parse(importlib.metadata.version("flash_attn_3")) >= version.parse("3.0.0")
return True
@lru_cache
def is_flash_attn_greater_or_equal_2_10():
if not _is_package_available("flash_attn"):

View File

@ -0,0 +1,144 @@
# Copyright 2025 Eduard Durech and SGLang team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Usage:
# RUN_SLOW=1 pytest -s tests/generation/test_flash_attention_parity.py
import unittest
import pytest
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.testing_utils import require_flash_attn, require_flash_attn_3, require_torch_gpu, slow
class FlashAttentionParityTest(unittest.TestCase):
# From https://github.com/sgl-project/sglang/blob/main/python/sglang/test/test_utils.py
def _lcs(self, X, Y):
m = len(X)
n = len(Y)
L = [[0] * (n + 1) for _ in range(m + 1)]
for i in range(m + 1):
for j in range(n + 1):
if i == 0 or j == 0:
L[i][j] = 0
elif X[i - 1] == Y[j - 1]:
L[i][j] = L[i - 1][j - 1] + 1
else:
L[i][j] = max(L[i - 1][j], L[i][j - 1])
return L[m][n]
# From https://github.com/sgl-project/sglang/blob/main/python/sglang/test/test_utils.py
def _calculate_rouge_l(self, output_strs_list1, output_strs_list2):
rouge_l_scores = []
for s1, s2 in zip(output_strs_list1, output_strs_list2):
lcs_len = self._lcs(s1, s2)
precision = lcs_len / len(s1) if len(s1) > 0 else 0
recall = lcs_len / len(s2) if len(s2) > 0 else 0
if precision + recall > 0:
fmeasure = (2 * precision * recall) / (precision + recall)
else:
fmeasure = 0.0
rouge_l_scores.append(fmeasure)
return rouge_l_scores
def _benchmark_generation(self, model, inputs, n_warmup=3, n_runs=5):
for _ in range(n_warmup):
model.generate(**inputs, max_new_tokens=20, do_sample=False)
torch.cuda.synchronize()
start_time = torch.cuda.Event(enable_timing=True)
end_time = torch.cuda.Event(enable_timing=True)
start_time.record()
for _ in range(n_runs):
model.generate(**inputs, max_new_tokens=20, do_sample=False)
end_time.record()
torch.cuda.synchronize()
return start_time.elapsed_time(end_time) / n_runs
@pytest.mark.flash_attn_3_test
@require_torch_gpu
@require_flash_attn
@require_flash_attn_3
@slow
def test_flash_attention_2_3_parity(self):
model_id = "meta-llama/Llama-3.2-1B-Instruct"
prompt = "The ETH AI Center is"
# 1. Load FA2 model and tokenizer
model_2 = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
).to("cuda")
tokenizer = AutoTokenizer.from_pretrained(model_id)
# 2. Load FA3 model
try:
model_3 = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_3",
).to("cuda")
except (ValueError, ImportError) as e:
pytest.skip(f"Could not load Flash Attention 3 model, skipping test. Error: {e}")
# 3. Generate with both models
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
with torch.no_grad():
output_2 = model_2.generate(
**inputs, max_new_tokens=20, do_sample=False, output_scores=True, return_dict_in_generate=True
)
output_3 = model_3.generate(
**inputs, max_new_tokens=20, do_sample=False, output_scores=True, return_dict_in_generate=True
)
# 4. Correctness check
# 4a. Logits
logits_2 = torch.stack(output_2.scores)
logits_3 = torch.stack(output_3.scores)
torch.testing.assert_close(logits_2, logits_3, atol=1e-3, rtol=1e-3)
logprobs_2 = torch.nn.functional.log_softmax(logits_2, dim=-1)
logprobs_3 = torch.nn.functional.log_softmax(logits_3, dim=-1)
max_logprob_diff = torch.max(torch.abs(logprobs_2 - logprobs_3)).item()
# 4b. Generated text
text_2 = tokenizer.decode(output_2.sequences[0], skip_special_tokens=True)
text_3 = tokenizer.decode(output_3.sequences[0], skip_special_tokens=True)
rouge_score = self._calculate_rouge_l([text_2], [text_3])[0]
assert rouge_score > 0.99, f"Generated texts do not match (ROUGE-L: {rouge_score})"
# 5. Performance check
with torch.no_grad():
time_2 = self._benchmark_generation(model_2, inputs)
time_3 = self._benchmark_generation(model_3, inputs)
print(f"\n--- Flash Attention {2, 3} Parity Test on {model_id} ---")
print(f"Prompt: '{prompt}'")
print(f"Generated text with Flash Attention 2: {text_2}")
print(f"Generated text with Flash Attention 3: {text_3}")
print(f"ROUGE-L: {rouge_score}")
print(f"Max absolute difference in logprobs: {max_logprob_diff:.5e}")
print(f"Flash Attention 2 latency: {time_2:.2f} ms")
print(f"Flash Attention 3 latency: {time_3:.2f} ms")
print(f"Speed-up: {time_2 / time_3:.2f}x")
print("---")

View File

@ -34,6 +34,7 @@ from transformers.testing_utils import (
is_flaky,
require_accelerate,
require_flash_attn,
require_flash_attn_3,
require_optimum_quanto,
require_read_token,
require_torch,
@ -2292,6 +2293,7 @@ class GenerationTesterMixin:
support_flag = {
"sdpa": "_supports_sdpa",
"flash_attention_2": "_supports_flash_attn_2",
"flash_attention_3": "_supports_flash_attn_3",
}
for model_class in self.all_generative_model_classes:
@ -2369,6 +2371,14 @@ class GenerationTesterMixin:
"""Tests that generate has equivalent outputs with FA2 and eager attention implementations."""
self._test_attention_implementation("flash_attention_2")
@pytest.mark.flash_attn_3_test
@require_flash_attn_3
@require_torch_gpu
@slow
def test_eager_matches_fa3_generate(self):
"""Tests that generate has equivalent outputs with FA3 and eager attention implementations."""
self._test_attention_implementation("flash_attention_3")
def _check_generate_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1):
input_batch_size = int(output.sequences.shape[0] / num_return_sequences)
internal_batch_size = (

View File

@ -84,6 +84,7 @@ from transformers.testing_utils import (
require_bitsandbytes,
require_deepspeed,
require_flash_attn,
require_flash_attn_3,
require_non_hpu,
require_safetensors,
require_torch,
@ -3129,18 +3130,19 @@ class ModelTesterMixin:
f"{model_class} is too big for the common tests ({num_params})! It should have 1M max."
)
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
@is_flaky()
def test_flash_attn_2_inference_equivalence(self):
def flash_attn_inference_equivalence(self, attn_implementation: str, padding_side: str):
r"""
Tests the equivalence between the eager and flash attention implementations.
This test is only for inference and runs with `torch_dtype=torch.bfloat16`.
"""
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
if (attn_implementation == "flash_attention_2" and not model_class._supports_flash_attn_2) or (
attn_implementation == "flash_attention_3" and not model_class._supports_flash_attn_3
):
self.skipTest(f"{model_class.__name__} does not support {attn_implementation}")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
@ -3148,7 +3150,7 @@ class ModelTesterMixin:
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation=attn_implementation
)
model_fa.to(torch_device)
@ -3163,9 +3165,12 @@ class ModelTesterMixin:
if dummy_attention_mask is not None:
dummy_attention_mask = dummy_attention_mask[:1]
dummy_attention_mask[:, 1:] = 1
dummy_attention_mask[:, :1] = 0
if padding_side == "left":
dummy_attention_mask[:, 1:] = 1
dummy_attention_mask[:, :1] = 0
else:
dummy_attention_mask[:, :-1] = 1
dummy_attention_mask[:, -1:] = 0
if model.config.is_encoder_decoder:
decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:1]
@ -3220,11 +3225,22 @@ class ModelTesterMixin:
else outputs_fa.decoder_hidden_states[-1]
)
assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2)
if padding_side == "left":
assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2)
# check with inference + dropout
model.train()
_ = model_fa(dummy_input, **other_inputs)
# check with inference + dropout
model.train()
_ = model_fa(dummy_input, **other_inputs)
else:
assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
@is_flaky()
def test_flash_attn_2_inference_equivalence(self):
self.flash_attn_inference_equivalence(attn_implementation="flash_attention_2", padding_side="left")
@require_flash_attn
@require_torch_gpu
@ -3232,92 +3248,23 @@ class ModelTesterMixin:
@slow
@is_flaky()
def test_flash_attn_2_inference_equivalence_right_padding(self):
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
self.flash_attn_inference_equivalence(attn_implementation="flash_attention_2", padding_side="right")
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
@require_flash_attn_3
@require_torch_gpu
@mark.flash_attn_3_test
@slow
@is_flaky()
def test_flash_attn_3_inference_equivalence(self):
self.flash_attn_inference_equivalence(attn_implementation="flash_attention_3", padding_side="left")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
model.to(torch_device)
dummy_input = inputs_dict[model.main_input_name][:1]
if dummy_input.dtype in [torch.float32, torch.float16]:
dummy_input = dummy_input.to(torch.bfloat16)
dummy_attention_mask = inputs_dict.get("attention_mask", None)
if dummy_attention_mask is not None:
dummy_attention_mask = dummy_attention_mask[:1]
dummy_attention_mask[:, :-1] = 1
dummy_attention_mask[:, -1:] = 0
if model.config.is_encoder_decoder:
decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:1]
outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
else:
outputs = model(dummy_input, output_hidden_states=True)
outputs_fa = model_fa(dummy_input, output_hidden_states=True)
logits = (
outputs.hidden_states[-1]
if not model.config.is_encoder_decoder
else outputs.decoder_hidden_states[-1]
)
logits_fa = (
outputs_fa.hidden_states[-1]
if not model.config.is_encoder_decoder
else outputs_fa.decoder_hidden_states[-1]
)
assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
if model.config.is_encoder_decoder:
other_inputs = {
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": dummy_attention_mask,
"output_hidden_states": True,
}
if dummy_attention_mask is not None:
other_inputs["attention_mask"] = dummy_attention_mask
outputs = model(dummy_input, **other_inputs)
outputs_fa = model_fa(dummy_input, **other_inputs)
else:
other_inputs = {
"output_hidden_states": True,
}
if dummy_attention_mask is not None:
other_inputs["attention_mask"] = dummy_attention_mask
outputs = model(dummy_input, **other_inputs)
outputs_fa = model_fa(dummy_input, **other_inputs)
logits = (
outputs.hidden_states[-1]
if not model.config.is_encoder_decoder
else outputs.decoder_hidden_states[-1]
)
logits_fa = (
outputs_fa.hidden_states[-1]
if not model.config.is_encoder_decoder
else outputs_fa.decoder_hidden_states[-1]
)
assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
@require_flash_attn_3
@require_torch_gpu
@mark.flash_attn_3_test
@slow
@is_flaky()
def test_flash_attn_3_inference_equivalence_right_padding(self):
self.flash_attn_inference_equivalence(attn_implementation="flash_attention_3", padding_side="right")
def test_attn_implementation_composite_models(self):
"""
@ -3959,24 +3906,21 @@ class ModelTesterMixin:
torch.allclose(res_eager[attention_mask == 1], res_sdpa[attention_mask == 1], rtol=1e-4, atol=1e-4)
)
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
def test_flash_attn_2_can_dispatch_composite_models(self):
def flash_attn_can_dispatch_composite_models(self, attn_implementation: str):
"""
Tests if composite models can dispatch on FA2 if the sub-models support FA2.
Tests if composite models can dispatch on flash attention if the sub-models support it.
The tests is needed as we handle differently composite models and we cannot check them
with above tests. If any of the sub-models does not support FA2, we'll raise an error when dispatching
with above tests. If any of the sub-models does not support flash attention, we'll raise an error when dispatching
that particular sub-model. Otherwise we dispatch safely in all sub-models, where "sub-models" are specific
backbone models (LM/vision/audio/etc)
"""
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
if not is_torch_fp16_available_on_device(torch_device):
self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)")
if not is_torch_bf16_available_on_device(torch_device):
self.skipTest(f"bfloat16 not supported on {torch_device} (on the specific device currently used)")
torch_dtype = torch.float16
torch_dtype = torch.bfloat16
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
@ -3987,44 +3931,64 @@ class ModelTesterMixin:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype)
sub_models_supporting_fa2 = [
module._supports_flash_attn_2
sub_models_supporting_fa = [
(
module._supports_flash_attn_3
if attn_implementation == "flash_attention_3"
else module._supports_flash_attn_2
)
for name, module in model.named_modules()
if isinstance(module, PreTrainedModel) and name != ""
]
supports_fa2_all_modules = (
all(sub_models_supporting_fa2)
if len(sub_models_supporting_fa2) > 0
else model._supports_flash_attn_2
supports_fa_all_modules = (
all(sub_models_supporting_fa)
if len(sub_models_supporting_fa) > 0
else (
model._supports_flash_attn_3
if attn_implementation == "flash_attention_3"
else model._supports_flash_attn_2
)
)
if not supports_fa2_all_modules:
if not supports_fa_all_modules:
with self.assertRaises(ValueError):
model_fa2 = model_class.from_pretrained(
model_fa = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch_dtype,
attn_implementation="flash_attention_2",
attn_implementation=attn_implementation,
)
else:
model_fa2 = model_class.from_pretrained(
tmpdirname, torch_dtype=torch_dtype, attn_implementation="flash_attention_2"
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch_dtype, attn_implementation=attn_implementation
)
for key in model_fa2.config:
if isinstance(getattr(model_fa2.config, key), PretrainedConfig):
sub_config = getattr(model_fa2.config, key)
self.assertTrue(sub_config._attn_implementation == "flash_attention_2")
for key in model_fa.config:
if isinstance(getattr(model_fa.config, key), PretrainedConfig):
sub_config = getattr(model_fa.config, key)
self.assertTrue(sub_config._attn_implementation == attn_implementation)
has_fa2 = False
for name, submodule in model_fa2.named_modules():
has_fa = False
for name, submodule in model_fa.named_modules():
class_name = submodule.__class__.__name__
if (
"Attention" in class_name
and getattr(submodule, "config", None)
and submodule.config._attn_implementation == "flash_attention_2"
and submodule.config._attn_implementation == attn_implementation
):
has_fa2 = True
has_fa = True
break
if not has_fa2:
raise ValueError("The FA2 model should have FA2 layers")
if not has_fa:
raise ValueError(f"The {attn_implementation} model should have {attn_implementation} layers")
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
def test_flash_attn_2_can_dispatch_composite_models(self):
self.flash_attn_can_dispatch_composite_models(attn_implementation="flash_attention_2")
@require_flash_attn_3
@require_torch_gpu
@mark.flash_attn_3_test
def test_flash_attn_3_can_dispatch_composite_models(self):
self.flash_attn_can_dispatch_composite_models(attn_implementation="flash_attention_3")
@require_flash_attn
@require_torch_gpu
@ -4121,27 +4085,29 @@ class ModelTesterMixin:
assert not loss.isnan().any()
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
def flash_attention_padding_matches_padding_free_with_position_ids(
self, attn_implementation: str, fa_kwargs: bool = False
):
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
max_new_tokens = 30
for model_class in self.all_generative_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
if not (
model_class._supports_flash_attn_2
if attn_implementation == "flash_attention_2"
else model_class._supports_flash_attn_3
):
self.skipTest(f"{model_class.__name__} does not support {attn_implementation}")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
if 0 not in inputs_dict.get("attention_mask", []) or "attention_mask" not in inputs_dict:
self.skipTest("Model dummy inputs should contain padding in their attention mask")
dummy_input = inputs_dict[model_class.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
if dummy_input.dtype in [torch.float32, torch.float16]:
dummy_input = dummy_input.to(torch.bfloat16)
# make sure that all models have enough positions for generation
if hasattr(config, "max_position_embeddings"):
@ -4151,7 +4117,7 @@ class ModelTesterMixin:
if "position_ids" not in inspect.signature(model.forward).parameters:
self.skipTest("Model does not support position_ids")
if "position_ids" not in inspect.signature(model.forward).parameters:
if (not fa_kwargs) and "position_ids" not in inspect.signature(model.forward).parameters:
continue # this model doesn't accept position ids as input
with tempfile.TemporaryDirectory() as tmpdirname:
@ -4166,26 +4132,40 @@ class ModelTesterMixin:
model = (
model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
torch_dtype=torch.bfloat16,
attn_implementation=attn_implementation,
)
.to(torch_device)
.eval()
)
# flatten
padfree_inputs_dict = {
k: v[dummy_attention_mask.bool()].unsqueeze(0)
for k, v in inputs_dict.items()
if not k == "attention_mask"
}
# add position_ids
padfree_inputs_dict["position_ids"] = (
torch.cat([torch.arange(length) for length in dummy_attention_mask.sum(1).tolist()])
.long()
.unsqueeze(0)
.to(torch_device)
)
if fa_kwargs:
# flatten
features = [
{"input_ids": i[a.bool()].tolist()}
for i, a in zip(inputs_dict["input_ids"], inputs_dict["attention_mask"])
]
# add position_ids + fa_kwargs
data_collator = DataCollatorWithFlattening(return_tensors="pt", return_flash_attn_kwargs=True)
batch = data_collator(features)
padfree_inputs_dict = {
k: t.to(torch_device) if torch.is_tensor(t) else t for k, t in batch.items()
}
else:
# flatten
padfree_inputs_dict = {
k: v[dummy_attention_mask.bool()].unsqueeze(0)
for k, v in inputs_dict.items()
if not k == "attention_mask"
}
# add position_ids
padfree_inputs_dict["position_ids"] = (
torch.cat([torch.arange(length) for length in dummy_attention_mask.sum(1).tolist()])
.long()
.unsqueeze(0)
.to(torch_device)
)
res_padded = model(**inputs_dict)
res_padfree = model(**padfree_inputs_dict)
@ -4195,119 +4175,96 @@ class ModelTesterMixin:
torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), rtol=0, atol=0)
# acceptable numerical instability
tol = torch.finfo(torch.float16).eps
tol = torch.finfo(torch.bfloat16).eps
torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
self.flash_attention_padding_matches_padding_free_with_position_ids(attn_implementation="flash_attention_2")
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
def test_flash_attention_2_padding_matches_padding_free_with_position_ids_and_fa_kwargs(self):
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
self.flash_attention_padding_matches_padding_free_with_position_ids(
attn_implementation="flash_attention_2", fa_kwargs=True
)
max_new_tokens = 30
for model_class in self.all_generative_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
if 0 not in inputs_dict.get("attention_mask", []) or "attention_mask" not in inputs_dict:
self.skipTest("Model dummy inputs should contain padding in their attention mask")
dummy_input = inputs_dict[model_class.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
# make sure that all models have enough positions for generation
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
model = model_class(config)
if "position_ids" not in inspect.signature(model.forward).parameters:
self.skipTest("Model does not support position_ids")
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
# ensure left padding, to adapt for some models
if 0 in inputs_dict["attention_mask"][:, -1]:
inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)
dummy_attention_mask = inputs_dict["attention_mask"]
inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.get_text_config().pad_token_id
model = (
model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
)
.to(torch_device)
.eval()
)
# flatten
features = [
{"input_ids": i[a.bool()].tolist()}
for i, a in zip(inputs_dict["input_ids"], inputs_dict["attention_mask"])
]
# add position_ids + fa_kwargs
data_collator = DataCollatorWithFlattening(return_tensors="pt", return_flash_attn_kwargs=True)
batch = data_collator(features)
batch_accelerator = {k: t.to(torch_device) if torch.is_tensor(t) else t for k, t in batch.items()}
res_padded = model(**inputs_dict)
res_padfree = model(**batch_accelerator)
logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()]
logits_padfree = res_padfree.logits[0]
torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), rtol=0, atol=0)
# acceptable numerical instability
tol = torch.finfo(torch.float16).eps
torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
@require_flash_attn
@require_flash_attn_3
@require_torch_gpu
@mark.flash_attn_test
@mark.flash_attn_3_test
@slow
def test_flash_attn_2_from_config(self):
def test_flash_attention_3_padding_matches_padding_free_with_position_ids(self):
self.flash_attention_padding_matches_padding_free_with_position_ids(attn_implementation="flash_attention_3")
@require_flash_attn_3
@require_torch_gpu
@mark.flash_attn_3_test
@slow
def test_flash_attention_3_padding_matches_padding_free_with_position_ids_and_fa_kwargs(self):
self.flash_attention_padding_matches_padding_free_with_position_ids(
attn_implementation="flash_attention_3", fa_kwargs=True
)
def flash_attn_from_config(self, attn_implementation: str):
r"""
Tests if the model can be loaded with `attn_implementation` from the config and if the
weights are not randomly initialized.
"""
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
for model_class in self.all_generative_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
if (attn_implementation == "flash_attention_2" and not model_class._supports_flash_attn_2) or (
attn_implementation == "flash_attention_3" and not model_class._supports_flash_attn_3
):
self.skipTest(f"{model_class.__name__} does not support {attn_implementation}")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# TODO: to change it in the future with other relevant auto classes
fa2_model = model_class._from_config(
config, attn_implementation="flash_attention_2", torch_dtype=torch.float16
fa_model = model_class._from_config(
config, attn_implementation=attn_implementation, torch_dtype=torch.bfloat16
).to(torch_device)
dummy_input = inputs_dict[fa2_model.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
dummy_input = inputs_dict[fa_model.main_input_name]
if dummy_input.dtype in [torch.float32, torch.float16]:
dummy_input = dummy_input.to(torch.bfloat16)
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
if fa2_model.config.is_encoder_decoder:
if fa_model.config.is_encoder_decoder:
dummy_decoder_input_ids = inputs_dict["decoder_input_ids"]
dummy_decoder_attention_mask = inputs_dict["decoder_attention_mask"]
_ = fa2_model(
_ = fa_model(
dummy_input,
attention_mask=dummy_attention_mask,
decoder_input_ids=dummy_decoder_input_ids,
decoder_attention_mask=dummy_decoder_attention_mask,
)
else:
_ = fa2_model(dummy_input, attention_mask=dummy_attention_mask)
_ = fa_model(dummy_input, attention_mask=dummy_attention_mask)
with tempfile.TemporaryDirectory() as tmpdirname:
fa2_model.save_pretrained(tmpdirname)
fa_model.save_pretrained(tmpdirname)
model_from_pretrained = model_class.from_pretrained(tmpdirname)
self.assertTrue(model_from_pretrained.config._attn_implementation != "flash_attention_2")
self.assertTrue(model_from_pretrained.config._attn_implementation != attn_implementation)
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
def test_flash_attn_2_from_config(self):
self.flash_attn_from_config(attn_implementation="flash_attention_2")
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_3_test
@slow
def test_flash_attn_3_from_config(self):
self.flash_attn_from_config(attn_implementation="flash_attention_3")
def _get_custom_4d_mask_test_data(self):
# Sequence in which all but the last token is the same

View File

@ -77,6 +77,7 @@ from transformers.utils import (
)
from transformers.utils.import_utils import (
is_flash_attn_2_available,
is_flash_attn_3_available,
is_flax_available,
is_tf_available,
is_torch_npu_available,
@ -676,6 +677,9 @@ class ModelUtilsTest(TestCasePlus):
if is_flash_attn_available():
attn_implementation_available.append("flash_attention_2")
if is_flash_attn_3_available():
attn_implementation_available.append("flash_attention_3")
for requested_attn_implementation in attn_implementation_available:
model = AutoModelForCausalLM.from_pretrained(
TINY_MISTRAL, attn_implementation=requested_attn_implementation
@ -700,6 +704,9 @@ class ModelUtilsTest(TestCasePlus):
if is_flash_attn_available():
attn_implementation_available.append("flash_attention_2")
if is_flash_attn_3_available():
attn_implementation_available.append("flash_attention_3")
for requested_attn_implementation in attn_implementation_available:
config = AutoConfig.from_pretrained(TINY_MISTRAL, attn_implementation=requested_attn_implementation)
# Ensure the config was set correctly