From a2eb75c891f6866cc9aeb66896be59f6c4ce100e Mon Sep 17 00:00:00 2001 From: EduardDurech <39579228+EduardDurech@users.noreply.github.com> Date: Wed, 25 Jun 2025 14:39:27 +0200 Subject: [PATCH] Support for Flash Attention 3 (#38972) * Support `flash_attn_3` Implements fwd and tests for Flash Attention 3 https://github.com/Dao-AILab/flash-attention/commits/main/hopper - Includes checks for dropout>0 and ALiBi in `modeling_utils.PreTrainedModel._check_and_enable_flash_attn_3` (Dropout will likely be supported soon, so this will need to be updated and `modeling_flash_attention_utils._flash_attention_forward` at the `if _IS_FLASH_ATTN_3_AVAILABLE: ...` An example Llama implementation is included in `modeling_llama.py` but other models would still need to be updated Based on https://github.com/huggingface/transformers/pull/36190 which has model implementations and examples which could be merged * Add tests for Flash Attention 2 and 3 parity * ci fix * FA2 compatibiity - `_prepare_flash_attention_from_position_ids` ->`prepare_fa2_from_position_ids` - Remove bettertransformer check in Flash Attention 3 - Merge tests - Add licensing * ci fix * Test naming consistency * ci fix * Deprecation warning for `prepare_fa2_from_position_ids` * ci fix --- pyproject.toml | 1 + .../integrations/flash_attention.py | 1 + .../modeling_flash_attention_utils.py | 198 +++++++- src/transformers/modeling_utils.py | 107 ++++- .../models/arcee/modeling_arcee.py | 1 + src/transformers/models/aria/modeling_aria.py | 1 + .../models/bitnet/modeling_bitnet.py | 1 + .../models/cohere/modeling_cohere.py | 1 + .../models/cohere2/modeling_cohere2.py | 1 + .../deepseek_v3/modeling_deepseek_v3.py | 1 + .../models/diffllama/modeling_diffllama.py | 1 + .../models/dots1/modeling_dots1.py | 1 + .../models/gemma/modeling_gemma.py | 1 + .../models/gemma2/modeling_gemma2.py | 1 + .../models/gemma3/modeling_gemma3.py | 1 + src/transformers/models/glm/modeling_glm.py | 1 + src/transformers/models/glm4/modeling_glm4.py | 1 + .../models/gpt_neox/modeling_gpt_neox.py | 1 + .../models/granite/modeling_granite.py | 1 + .../models/helium/modeling_helium.py | 1 + .../models/llama/modeling_llama.py | 1 + .../models/minimax/modeling_minimax.py | 1 + .../models/mistral/modeling_mistral.py | 1 + .../models/mixtral/modeling_mixtral.py | 1 + src/transformers/models/olmo/modeling_olmo.py | 1 + .../models/olmo2/modeling_olmo2.py | 1 + src/transformers/models/phi/modeling_phi.py | 1 + src/transformers/models/phi3/modeling_phi3.py | 1 + .../modeling_phi4_multimodal.py | 1 + .../models/qwen2/modeling_qwen2.py | 1 + .../models/qwen3/modeling_qwen3.py | 1 + .../models/qwen3_moe/modeling_qwen3_moe.py | 1 + .../models/starcoder2/modeling_starcoder2.py | 1 + .../models/t5gemma/modeling_t5gemma.py | 1 + src/transformers/testing_utils.py | 10 + src/transformers/utils/__init__.py | 1 + src/transformers/utils/args_doc.py | 3 + src/transformers/utils/import_utils.py | 19 + .../generation/test_flash_attention_parity.py | 144 ++++++ tests/generation/test_utils.py | 10 + tests/test_modeling_common.py | 429 ++++++++---------- tests/utils/test_modeling_utils.py | 7 + 42 files changed, 698 insertions(+), 262 deletions(-) create mode 100644 tests/generation/test_flash_attention_parity.py diff --git a/pyproject.toml b/pyproject.toml index af22cfe9c62..4e7a0c62d0f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,6 +52,7 @@ line-ending = "auto" addopts = "--doctest-glob='**/*.md'" doctest_optionflags="NUMBER NORMALIZE_WHITESPACE ELLIPSIS" markers = [ + 
"flash_attn_3_test: marks tests related to flash attention 3 (deselect with '-m \"not flash_attn_3_test\"')", "flash_attn_test: marks tests related to flash attention (deselect with '-m \"not flash_attn_test\"')", "bitsandbytes: select (or deselect with `not`) bitsandbytes integration tests", "generate: marks tests that use the GenerationTesterMixin" diff --git a/src/transformers/integrations/flash_attention.py b/src/transformers/integrations/flash_attention.py index 16fcc909817..00df0ef0fd6 100644 --- a/src/transformers/integrations/flash_attention.py +++ b/src/transformers/integrations/flash_attention.py @@ -75,6 +75,7 @@ def flash_attention_forward( softcap=softcap, use_top_left_mask=_use_top_left_mask, target_dtype=target_dtype, + attn_implementation=module.config._attn_implementation, **kwargs, ) diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index 7f3df329432..649447ca8f7 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -14,6 +14,7 @@ import inspect import os +import warnings from typing import Optional, TypedDict import torch @@ -21,6 +22,7 @@ import torch.nn.functional as F from .utils import ( is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal, is_flash_attn_greater_or_equal_2_10, is_torch_npu_available, @@ -32,18 +34,123 @@ logger = logging.get_logger(__name__) flash_attn_func = None -if is_flash_attn_2_available(): - from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa - from flash_attn import flash_attn_func, flash_attn_varlen_func - from flash_attn.layers.rotary import apply_rotary_emb # noqa +def _index_first_axis(tensor, indices): + """ + A local implementation of the PyTorch indexing operation `tensor[indices]` on the first axis, + after flattening the first two dimensions of the tensor. This is functionally equivalent to + FA2's `index_first_axis` and replaces the need to import it. + """ + # The input tensor is expected to be of shape (batch, seq_len, ...). We flatten the first + # two dimensions to get (total_tokens, ...) before indexing. + reshaped_tensor = tensor.reshape(-1, *tensor.shape[2:]) + return reshaped_tensor[indices] +def _fa3_unpad_input(hidden_states, attention_mask, unused_mask=None): + """ + FA3-compatible unpad_input function. + + Arguments: + hidden_states: (batch, seqlen, ...) + attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. + unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused. + Return: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask. + indices: (total_nnz), the indices of masked tokens from the flattened input sequence. + cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states. + max_seqlen_in_batch: int + seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask. 
+ """ + all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask + seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32) + used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + + return ( + _index_first_axis(hidden_states, indices), + indices, + cu_seqlens, + max_seqlen_in_batch, + used_seqlens_in_batch, + ) + + +def _fa3_pad_input(hidden_states, indices, batch, seqlen): + """ + FA3-compatible pad_input function. + + Arguments: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. + indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence. + batch: int, batch size for the padded sequence. + seqlen: int, maximum sequence length for the padded sequence. + Return: + hidden_states: (batch, seqlen, ...) + """ + dim = hidden_states.shape[1:] + output = torch.zeros((batch * seqlen), *dim, device=hidden_states.device, dtype=hidden_states.dtype) + output[indices] = hidden_states + return output.view(batch, seqlen, *dim) + + +FA_VERSION = None +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func as flash_attn_2_func + from flash_attn import flash_attn_varlen_func as flash_attn_2_varlen_func + from flash_attn.bert_padding import pad_input as pad_input_fa2 + from flash_attn.bert_padding import unpad_input as unpad_input_fa2 + from flash_attn.layers.rotary import apply_rotary_emb + + HAS_FA2 = True + FA_VERSION = 2 +else: + flash_attn_2_func = None + flash_attn_2_varlen_func = None + pad_input_fa2 = None + unpad_input_fa2 = None + apply_rotary_emb = None + HAS_FA2 = False + +if is_flash_attn_3_available(): + from flash_attn_interface import flash_attn_func as flash_attn_3_func + from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func + + pad_input_fa3 = _fa3_pad_input + unpad_input_fa3 = _fa3_unpad_input + HAS_FA3 = True + FA_VERSION = 3 +else: + flash_attn_3_func = None + flash_attn_3_varlen_func = None + pad_input_fa3 = None + unpad_input_fa3 = None + HAS_FA3 = False + + +# Current Flash Attention implementations +if FA_VERSION: + flash_attn_func = globals()[f"flash_attn_{FA_VERSION}_func"] + flash_attn_varlen_func = globals()[f"flash_attn_{FA_VERSION}_varlen_func"] + unpad_input = globals()[f"unpad_input_fa{FA_VERSION}"] + pad_input = globals()[f"pad_input_fa{FA_VERSION}"] + # patch functions in package `flash-attn` when using flash-attention on Ascend NPU. 
if is_torch_npu_available(): - from .integrations.npu_flash_attention import index_first_axis, pad_input, unpad_input - from .integrations.npu_flash_attention import npu_apply_rotary_emb as apply_rotary_emb # noqa - from .integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_func - from .integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_varlen_func + from .integrations.npu_flash_attention import ( + npu_apply_rotary_emb as apply_rotary_emb, # noqa: F401 + ) + from .integrations.npu_flash_attention import ( + npu_flash_attn_func as flash_attn_func, + ) + from .integrations.npu_flash_attention import ( + npu_flash_attn_varlen_func as flash_attn_varlen_func, + ) + from .integrations.npu_flash_attention import ( + pad_input, + unpad_input, + ) _flash_supports_window_size = False @@ -56,6 +163,9 @@ if flash_attn_func: def is_flash_attn_available(): """Determine whether flash-attention can be used or not.""" + if is_flash_attn_3_available(): + return True + # if package `flash-attn` is available, flash-attention can be used natively. if is_flash_attn_2_available(): return True @@ -70,6 +180,9 @@ def is_flash_attn_available(): def flash_attn_supports_top_left_mask(): """Determine whether flash-attention uses top-left or down-right mask""" + if is_flash_attn_3_available(): + return False + if is_flash_attn_2_available(): # top-left mask is used in package `flash-attn` with version lower than 2.1.0 return not is_flash_attn_greater_or_equal_2_10() @@ -116,6 +229,7 @@ def _upad_input( value_layer: torch.Tensor, attention_mask: torch.Tensor, query_length: int, + unpad_input_func, ): """ Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches. @@ -134,6 +248,8 @@ def _upad_input( Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. query_length (`int`): Target length. + unpad_input_func: + The function to use for unpadding the input tensors. Return: query_layer (`torch.Tensor`): @@ -158,12 +274,10 @@ def _upad_input( batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape - key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k) - value_layer = index_first_axis( - value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k - ) + key_layer = _index_first_axis(key_layer, indices_k) + value_layer = _index_first_axis(value_layer, indices_k) if query_length == kv_seq_len: - query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, -1, head_dim), indices_k) + query_layer = _index_first_axis(query_layer, indices_k) cu_seqlens_q = cu_seqlens_k max_seqlen_in_batch_q = max_seqlen_in_batch_k indices_q = indices_k @@ -177,7 +291,7 @@ def _upad_input( else: # The -q_len: slice assumes left padding. attention_mask = attention_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q, *_ = unpad_input(query_layer, attention_mask) + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q, *_ = unpad_input_func(query_layer, attention_mask) return ( query_layer, @@ -189,7 +303,7 @@ def _upad_input( ) -def prepare_fa2_from_position_ids(query, key, value, position_ids): +def _prepare_flash_attention_from_position_ids(query, key, value, position_ids): """ This function returns necessary arguments to call `flash_attn_varlen_func`. All three query, key, value states will be flattened. 
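For readers unfamiliar with the packed-sequence path, here is a minimal standalone sketch (torch only, toy shapes, separate from the diff hunks) of how cumulative sequence lengths can be recovered from `position_ids` when several sequences are packed into one batch row; it mirrors the idea behind `_prepare_flash_attention_from_position_ids` rather than reproducing it verbatim.

import torch
import torch.nn.functional as F

# One batch row packing two sequences (lengths 3 and 2): positions restart at 0 at each boundary.
position_ids = torch.tensor([[0, 1, 2, 0, 1]])
flat = position_ids.flatten()
token_idx = torch.arange(flat.size(0), dtype=torch.int32)
seq_starts = token_idx[flat == 0]                              # token offsets where each sequence begins -> [0, 3]
cu_seq_lens = F.pad(seq_starts, (0, 1), value=flat.size(0))    # cumulative boundaries -> [0, 3, 5]
max_length = int((cu_seq_lens[1:] - cu_seq_lens[:-1]).max())   # longest packed sequence -> 3
print(cu_seq_lens.tolist(), max_length)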
@@ -239,6 +353,14 @@ def prepare_fa2_from_position_ids(query, key, value, position_ids): return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length)) +def prepare_fa2_from_position_ids(*args, **kwargs): + warnings.warn( + "The function `prepare_fa2_from_position_ids` in `transformers.modeling_flash_attention_utils` is deprecated and will be removed in a future version. Please use `_prepare_flash_attention_from_position_ids` instead.", + FutureWarning, + ) + return _prepare_flash_attention_from_position_ids(*args, **kwargs) + + def fa_peft_integration_check( query: torch.Tensor, key: torch.Tensor, @@ -303,6 +425,7 @@ def _flash_attention_forward( max_length_q: Optional[int] = None, max_length_k: Optional[int] = None, target_dtype: Optional[torch.dtype] = None, + attn_implementation: Optional[str] = None, **kwargs, ): """ @@ -329,7 +452,28 @@ def _flash_attention_forward( Softcap for the attention logits, used e.g. in gemma2. deterministic (`bool`, *optional*): Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled. + attn_implementation (`str`, *optional*): + The attention implementation to use. If None, will default to the one based on the environment. """ + if attn_implementation is None: + _flash_attn_varlen_func = flash_attn_varlen_func + _flash_attn_func = flash_attn_func + _pad_input = pad_input + _unpad_input = unpad_input + _is_fa3 = HAS_FA3 + elif attn_implementation == "flash_attention_3": + _flash_attn_varlen_func = flash_attn_3_varlen_func + _flash_attn_func = flash_attn_3_func + _pad_input = pad_input_fa3 + _unpad_input = unpad_input_fa3 + _is_fa3 = True + elif attn_implementation == "flash_attention_2": + _flash_attn_varlen_func = flash_attn_2_varlen_func + _flash_attn_func = flash_attn_2_func + _pad_input = pad_input_fa2 + _unpad_input = unpad_input_fa2 + _is_fa3 = False + if not use_top_left_mask: causal = is_causal else: @@ -342,6 +486,12 @@ def _flash_attention_forward( ) flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {} + if _is_fa3: + if dropout > 0.0: + logger.warning_once("Flash Attention 3 does not support dropout. Setting dropout to 0.0.") + else: + flash_kwargs["dropout_p"] = dropout + if flash_241: if deterministic is None: global deterministic_g @@ -362,12 +512,12 @@ def _flash_attention_forward( if attention_mask is not None: batch_size = query_states.shape[0] query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = _upad_input( - query_states, key_states, value_states, attention_mask, query_length + query_states, key_states, value_states, attention_mask, query_length, _unpad_input ) cu_seqlens_q, cu_seqlens_k = cu_seq_lens max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens - attn_output_unpad = flash_attn_varlen_func( + attn_output_unpad = _flash_attn_varlen_func( query_states, key_states, value_states, @@ -375,12 +525,11 @@ def _flash_attention_forward( cu_seqlens_k=cu_seqlens_k, max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k, - dropout_p=dropout, softmax_scale=softmax_scale, causal=causal, **flash_kwargs, ) - attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + attn_output = _pad_input(attn_output_unpad, indices_q, batch_size, query_length) # If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing # then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage. 
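For context, a minimal standalone sketch (torch only, toy shapes, separate from the diff hunks) of the unpad/pad round trip that `_fa3_unpad_input` and `_fa3_pad_input` perform around the varlen kernel: valid tokens are gathered into a flat `(total_tokens, ...)` tensor plus cumulative sequence lengths, and scattered back into the zero-padded layout afterwards.

import torch
import torch.nn.functional as F

batch, seqlen, dim = 2, 4, 8
hidden = torch.randn(batch, seqlen, dim)
# 1 = valid token, 0 = padding (second row has one right-padded position).
mask = torch.tensor([[1, 1, 1, 1],
                     [1, 1, 1, 0]], dtype=torch.int32)

seqlens = mask.sum(dim=-1, dtype=torch.int32)                       # valid tokens per row -> [4, 3]
indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten()   # flat positions of valid tokens
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))  # -> [0, 4, 7]

unpadded = hidden.reshape(-1, dim)[indices]        # what the unpad step hands to the varlen kernel
assert unpadded.shape[0] == int(seqlens.sum())

repadded = torch.zeros(batch * seqlen, dim)        # the pad step scatters results back
repadded[indices] = unpadded
repadded = repadded.view(batch, seqlen, dim)
assert torch.equal(repadded[mask.bool()], hidden[mask.bool()])
print(cu_seqlens.tolist())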
@@ -394,7 +543,7 @@ def _flash_attention_forward( if cu_seq_lens_q is None or cu_seq_lens_k is None: query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = ( - prepare_fa2_from_position_ids(query_states, key_states, value_states, position_ids) + _prepare_flash_attention_from_position_ids(query_states, key_states, value_states, position_ids) ) cu_seq_lens_q, cu_seq_lens_k = cu_seq_lens @@ -405,7 +554,7 @@ def _flash_attention_forward( key_states = key_states.reshape(-1, key_states.size(-2), key_states.size(-1)) value_states = value_states.reshape(-1, value_states.size(-2), value_states.size(-1)) - attn_output = flash_attn_varlen_func( + attn_output = _flash_attn_varlen_func( query_states, key_states, value_states, @@ -413,7 +562,6 @@ def _flash_attention_forward( cu_seqlens_k=cu_seq_lens_k, max_seqlen_q=max_length_q, max_seqlen_k=max_length_k, - dropout_p=dropout, softmax_scale=softmax_scale, causal=causal, **flash_kwargs, @@ -422,10 +570,12 @@ def _flash_attention_forward( attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1)) else: - attn_output = flash_attn_func( - query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal, **flash_kwargs + attn_output = _flash_attn_func( + query_states, key_states, value_states, softmax_scale=softmax_scale, causal=causal, **flash_kwargs ) + if isinstance(attn_output, tuple): + return attn_output[0] return attn_output diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 4f6095a3edd..a5d1be345d1 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -105,6 +105,7 @@ from .utils import ( is_accelerate_available, is_bitsandbytes_available, is_flash_attn_2_available, + is_flash_attn_3_available, is_kernels_available, is_offline_mode, is_optimum_available, @@ -1957,6 +1958,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi # Flash Attention 2 support _supports_flash_attn_2 = False + # Flash Attention 3 support + _supports_flash_attn_3 = False + # SDPA support _supports_sdpa = False @@ -2247,6 +2251,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi and config._attn_implementation not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys() ): message = f'Specified `attn_implementation="{config._attn_implementation}"` is not supported. 
The only possible arguments are `attn_implementation="eager"` (manual attention implementation)' + if cls._supports_flash_attn_3: + message += ', `"attn_implementation=flash_attention_3"` (implementation using flash attention 3)' if cls._supports_flash_attn_2: message += ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)' if cls._supports_sdpa: @@ -2282,7 +2288,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi ): sub_config._attn_implementation_internal = curr_attn_implementation - if config._attn_implementation == "flash_attention_2": + if config._attn_implementation == "flash_attention_3": + cls._check_and_enable_flash_attn_3( + config, + torch_dtype=torch_dtype, + device_map=device_map, + hard_check_only=False, + check_device_map=check_device_map, + ) + elif config._attn_implementation == "flash_attention_2": cls._check_and_enable_flash_attn_2( config, torch_dtype=torch_dtype, @@ -2498,6 +2512,94 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi config._attn_implementation = "flash_attention_2" return config + @classmethod + def _check_and_enable_flash_attn_3( + cls, + config, + torch_dtype: Optional[torch.dtype] = None, + device_map: Optional[Union[str, dict[str, int]]] = None, + check_device_map: bool = True, + hard_check_only: bool = False, + ) -> PretrainedConfig: + """ + Checks the availability of Flash Attention 3 and compatibility with the current model. + + If all checks pass and `hard_check_only` is False, the method will set the config attribute `attn_implementation` to "flash_attention_3" so that the model can initialize the correct attention module. + """ + if not cls._supports_flash_attn_3: + raise ValueError( + f"{cls.__name__} does not support Flash Attention 3.0 yet. Please request to add support where" + f" the model is hosted, on its model hub page: https://huggingface.co/{config._name_or_path}/discussions/new" + " or in the Transformers GitHub repo: https://github.com/huggingface/transformers/issues/new" + ) + + if not is_flash_attn_3_available(): + preface = "FlashAttention3 has been toggled on, but it cannot be used due to the following error:" + + if importlib.util.find_spec("flash_attn_3") is None: + raise ImportError(f"{preface} the package flash_attn_3 seems to be not installed.") + + if torch.cuda.is_available(): + major, _ = torch.cuda.get_device_capability() + if major < 9: + raise ValueError( + f"{preface} Flash Attention 3 requires compute capability >= 9.0, but found {torch.cuda.get_device_capability()} with compute capability {major}.0." + ) + else: + raise ImportError(f"{preface} Flash Attention 3 is not available.") + else: + raise ValueError( + f"{preface} Flash Attention 3 is not available on CPU. Please make sure torch can access a CUDA device." + ) + + if torch_dtype is None: + logger.warning_once( + "You are attempting to use Flash Attention 3 without specifying a torch dtype. This might lead to unexpected behaviour" + ) + elif torch_dtype is not None and torch_dtype not in [torch.float16, torch.bfloat16]: + logger.warning_once( + "Flash Attention 3 only supports torch.float16 and torch.bfloat16 dtypes, but" + f" the current dype in {cls.__name__} is {torch_dtype}. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator," + ' or load the model with the `torch_dtype` argument. 
Example: `model = AutoModel.from_pretrained("meta-llama/Llama-3.2-1B", attn_implementation="flash_attention_3", torch_dtype=torch.float16)`' + ) + + if getattr(config, "alibi", False) or getattr(config, "use_alibi", False): + raise ValueError("Model is configured to use ALiBi, which is not supported by Flash Attention 3.") + + # Check for attention dropout, which is incompatible with FA3 + if hasattr(config, "attention_dropout") and config.attention_dropout > 0: + raise ValueError( + f"Model has attention_dropout={config.attention_dropout}, which is not supported by Flash Attention 3." + ) + + # The check `torch.empty(0).device.type != "cuda"` is needed as the model may be initialized after `torch.set_default_device` has been called, + # or the model may be initialized under the context manager `with torch.device("cuda"):`. + if check_device_map and device_map is None and torch.empty(0).device.type not in ["cuda", "mlu"]: + if torch.cuda.is_available(): + logger.warning_once( + "You are attempting to use Flash Attention 3 with a model not initialized on GPU. Make sure to move the model to GPU" + " after initializing it on CPU with `model.to('cuda')`." + ) + else: + raise ValueError( + "You are attempting to use Flash Attention 3 with a model not initialized on GPU and with no GPU available. " + "This is not supported yet. Please make sure to have access to a GPU and either initialise the model on a GPU by passing a device_map " + "or initialising the model on CPU and then moving it to GPU." + ) + elif ( + check_device_map + and device_map is not None + and isinstance(device_map, dict) + and ("cpu" in device_map.values() or "disk" in device_map.values()) + ): + raise ValueError( + "You are attempting to use Flash Attention 3 with a model dispatched on CPU or disk. This is not supported. Please make sure to " + "initialise the model on a GPU by passing a device_map that contains only GPU devices as keys." + ) + if not hard_check_only: + config._attn_implementation = "flash_attention_3" + return config + @classmethod def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> PretrainedConfig: """ @@ -4134,7 +4236,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi attn_implementation (`str`, *optional*): - The attention implementation to use in the model (if relevant). Can be any of `"eager"` (manual implementation of the attention), `"sdpa"` (using [`F.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)), or `"flash_attention_2"` (using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation. + The attention implementation to use in the model (if relevant). Can be any of `"eager"` (manual implementation of the attention), `"sdpa"` (using [`F.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)), `"flash_attention_2"` (using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)), or `"flash_attention_3"` (using [Dao-AILab/flash-attention/hopper](https://github.com/Dao-AILab/flash-attention/tree/main/hopper)). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation. 
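End to end, opting into the new backend only requires passing `attn_implementation="flash_attention_3"` to `from_pretrained`, as the parity test further down does; the sketch below is illustrative and assumes a Hopper-class GPU (compute capability >= 9.0), an installed `flash_attn_3` package, and access to the same checkpoint used in that test.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-3.2-1B-Instruct"   # same checkpoint as the parity test below
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,                 # FA3 supports fp16/bf16 only
    attn_implementation="flash_attention_3",
).to("cuda")

inputs = tokenizer("The ETH AI Center is", return_tensors="pt").to("cuda")
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
print(tokenizer.decode(output[0], skip_special_tokens=True))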
> Parameters for big model inference @@ -5770,6 +5872,7 @@ class AttentionInterface(GeneralInterface): # Class instance object, so that a call to `register` can be reflected into all other files correctly, even if # a new instance is created (in order to locally override a given function) _global_mapping = { + "flash_attention_3": flash_attention_forward, "flash_attention_2": flash_attention_forward, "flex_attention": flex_attention_forward, "paged_attention": paged_attention_forward, diff --git a/src/transformers/models/arcee/modeling_arcee.py b/src/transformers/models/arcee/modeling_arcee.py index dc8b7880c41..c224c4300eb 100644 --- a/src/transformers/models/arcee/modeling_arcee.py +++ b/src/transformers/models/arcee/modeling_arcee.py @@ -321,6 +321,7 @@ class ArceePreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["ArceeDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_3 = True _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index f62069a09f4..87f11d19269 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -667,6 +667,7 @@ class AriaPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["AriaDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_3 = True _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True diff --git a/src/transformers/models/bitnet/modeling_bitnet.py b/src/transformers/models/bitnet/modeling_bitnet.py index f526802bfca..afafd3f9118 100644 --- a/src/transformers/models/bitnet/modeling_bitnet.py +++ b/src/transformers/models/bitnet/modeling_bitnet.py @@ -318,6 +318,7 @@ class BitNetPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["BitNetDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_3 = True _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 88ca4e31de1..ad1604bed4a 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -355,6 +355,7 @@ class CoherePreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["CohereDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_3 = True _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py index 6999f1632f9..3fec29e9760 100644 --- a/src/transformers/models/cohere2/modeling_cohere2.py +++ b/src/transformers/models/cohere2/modeling_cohere2.py @@ -334,6 +334,7 @@ class Cohere2PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["Cohere2DecoderLayer"] _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_3 = True _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True diff --git a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py index 6eb50621891..541ae6669e9 100644 --- 
a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py @@ -504,6 +504,7 @@ class DeepseekV3PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["DeepseekV3DecoderLayer"] _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_3 = True _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py index fae9f2dbb95..383c329c990 100644 --- a/src/transformers/models/diffllama/modeling_diffllama.py +++ b/src/transformers/models/diffllama/modeling_diffllama.py @@ -556,6 +556,7 @@ class DiffLlamaPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["DiffLlamaDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_3 = True _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = False diff --git a/src/transformers/models/dots1/modeling_dots1.py b/src/transformers/models/dots1/modeling_dots1.py index b10fae6dbc8..58b805cca61 100644 --- a/src/transformers/models/dots1/modeling_dots1.py +++ b/src/transformers/models/dots1/modeling_dots1.py @@ -424,6 +424,7 @@ class Dots1PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["Dots1DecoderLayer"] _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_3 = True _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 1f8da9ed0ec..04b438c5ab4 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -318,6 +318,7 @@ class GemmaPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["GemmaDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_3 = True _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 7008538c7ab..bfd3317946b 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -339,6 +339,7 @@ class Gemma2PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["Gemma2DecoderLayer"] _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_3 = True _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py index db15678c25c..084ef0893a7 100644 --- a/src/transformers/models/gemma3/modeling_gemma3.py +++ b/src/transformers/models/gemma3/modeling_gemma3.py @@ -422,6 +422,7 @@ class Gemma3PreTrainedModel(PreTrainedModel): "SiglipMultiheadAttentionPoolingHead", ] _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_3 = True _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 2ee6273c00d..86538fc25e5 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -335,6 +335,7 @@ class 
GlmPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["GlmDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_3 = True _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True diff --git a/src/transformers/models/glm4/modeling_glm4.py b/src/transformers/models/glm4/modeling_glm4.py index 75487c5fccf..55cc8869d95 100644 --- a/src/transformers/models/glm4/modeling_glm4.py +++ b/src/transformers/models/glm4/modeling_glm4.py @@ -343,6 +343,7 @@ class Glm4PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["Glm4DecoderLayer"] _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_3 = True _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index d3c5141371b..2e563e401f2 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -292,6 +292,7 @@ class GPTNeoXPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["GPTNeoXLayer"] _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_3 = True _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index d1d69f9579c..b65530c4061 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -305,6 +305,7 @@ class GranitePreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["GraniteDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_3 = True _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True diff --git a/src/transformers/models/helium/modeling_helium.py b/src/transformers/models/helium/modeling_helium.py index 31d9f963049..3a48d931ca1 100644 --- a/src/transformers/models/helium/modeling_helium.py +++ b/src/transformers/models/helium/modeling_helium.py @@ -320,6 +320,7 @@ class HeliumPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["HeliumDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_3 = True _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 3a200ad988b..e79a7697602 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -320,6 +320,7 @@ class LlamaPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["LlamaDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_3 = True _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True diff --git a/src/transformers/models/minimax/modeling_minimax.py b/src/transformers/models/minimax/modeling_minimax.py index 0709d31f558..66ed4adcea4 100644 --- a/src/transformers/models/minimax/modeling_minimax.py +++ b/src/transformers/models/minimax/modeling_minimax.py @@ -590,6 +590,7 @@ class MiniMaxPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = 
["MiniMaxDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_3 = True _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 2576c85a785..4b222eabe23 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -262,6 +262,7 @@ class MistralPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["MistralDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_3 = True _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index ae0fd74e566..526bf2bbd75 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -417,6 +417,7 @@ class MixtralPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["MixtralDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_3 = True _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index c35988e2b8d..fc6a7188623 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -301,6 +301,7 @@ class OlmoPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["OlmoDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_3 = True _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index 8e69f43d3eb..84f5e5ad4e8 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -305,6 +305,7 @@ class Olmo2PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["Olmo2DecoderLayer"] _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_3 = True _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 95164a5f5db..1c513604406 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -295,6 +295,7 @@ class PhiPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["PhiDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_3 = True _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 79703927021..54fd3d1caf7 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -316,6 +316,7 @@ class Phi3PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["Phi3DecoderLayer"] _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_3 = True _supports_flash_attn_2 = True _supports_sdpa = True 
_supports_flex_attn = True diff --git a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py index a9a902598c1..27c199bf50a 100644 --- a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +++ b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py @@ -1622,6 +1622,7 @@ class Phi4MultimodalPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["Phi4MultimodalDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_3 = True _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index aaebc3c82bd..4ba0b43e134 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -266,6 +266,7 @@ class Qwen2PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["Qwen2DecoderLayer"] _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_3 = True _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True diff --git a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py index 6da04485704..e64f9667597 100644 --- a/src/transformers/models/qwen3/modeling_qwen3.py +++ b/src/transformers/models/qwen3/modeling_qwen3.py @@ -292,6 +292,7 @@ class Qwen3PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["Qwen3DecoderLayer"] _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_3 = True _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True diff --git a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py index 329da67a1e6..47ec0d10ab1 100644 --- a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -424,6 +424,7 @@ class Qwen3MoePreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["Qwen3MoeDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_3 = True _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index b0179a518bb..1e1d9c64363 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -299,6 +299,7 @@ class Starcoder2PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["Starcoder2DecoderLayer"] _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_3 = True _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True diff --git a/src/transformers/models/t5gemma/modeling_t5gemma.py b/src/transformers/models/t5gemma/modeling_t5gemma.py index 7f3ce0927a5..a6cec1c0997 100644 --- a/src/transformers/models/t5gemma/modeling_t5gemma.py +++ b/src/transformers/models/t5gemma/modeling_t5gemma.py @@ -561,6 +561,7 @@ class T5GemmaPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["T5GemmaBlock"] _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_3 = True 
_supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 1a4232adc8c..2ddbd51d414 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -86,6 +86,7 @@ from .utils import ( is_faiss_available, is_fbgemm_gpu_available, is_flash_attn_2_available, + is_flash_attn_3_available, is_flax_available, is_flute_available, is_fsdp_available, @@ -571,6 +572,15 @@ def require_flash_attn(test_case): return unittest.skipUnless(is_flash_attn_2_available(), "test requires Flash Attention")(test_case) +def require_flash_attn_3(test_case): + """ + Decorator marking a test that requires Flash Attention 3. + + These tests are skipped when Flash Attention 3 isn't installed. + """ + return unittest.skipUnless(is_flash_attn_3_available(), "test requires Flash Attention 3")(test_case) + + def require_torch_sdpa(test_case): """ Decorator marking a test that requires PyTorch's SDPA. diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 6d73b8d0325..7ca4c355280 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -153,6 +153,7 @@ from .import_utils import ( is_faiss_available, is_fbgemm_gpu_available, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal, is_flash_attn_greater_or_equal_2_10, is_flax_available, diff --git a/src/transformers/utils/args_doc.py b/src/transformers/utils/args_doc.py index 00cf4009fa5..61f947516ff 100644 --- a/src/transformers/utils/args_doc.py +++ b/src/transformers/utils/args_doc.py @@ -926,6 +926,9 @@ class ClassAttrs: _skip_keys_device_placement = r""" A list of keys to ignore when moving inputs or outputs between devices when using the `accelerate` library. """ + _supports_flash_attn_3 = r""" + Whether the model's attention implementation supports FlashAttention 3.0. + """ _supports_flash_attn_2 = r""" Whether the model's attention implementation supports FlashAttention 2.0. """ diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 7956f1b22d4..014366cc977 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -1120,6 +1120,25 @@ def is_flash_attn_2_available(): return False +@lru_cache() +def is_flash_attn_3_available(): + if not is_torch_available(): + return False + + if not _is_package_available("flash_attn_3"): + return False + + import torch + + if not torch.cuda.is_available(): + return False + + # TODO: Check for a minimum version when FA3 is stable + # return version.parse(importlib.metadata.version("flash_attn_3")) >= version.parse("3.0.0") + + return True + + @lru_cache def is_flash_attn_greater_or_equal_2_10(): if not _is_package_available("flash_attn"): diff --git a/tests/generation/test_flash_attention_parity.py b/tests/generation/test_flash_attention_parity.py new file mode 100644 index 00000000000..187bdfe24cd --- /dev/null +++ b/tests/generation/test_flash_attention_parity.py @@ -0,0 +1,144 @@ +# Copyright 2025 Eduard Durech and SGLang team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Usage: +# RUN_SLOW=1 pytest -s tests/generation/test_flash_attention_parity.py + +import unittest + +import pytest +import torch + +from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers.testing_utils import require_flash_attn, require_flash_attn_3, require_torch_gpu, slow + + +class FlashAttentionParityTest(unittest.TestCase): + # From https://github.com/sgl-project/sglang/blob/main/python/sglang/test/test_utils.py + def _lcs(self, X, Y): + m = len(X) + n = len(Y) + L = [[0] * (n + 1) for _ in range(m + 1)] + + for i in range(m + 1): + for j in range(n + 1): + if i == 0 or j == 0: + L[i][j] = 0 + elif X[i - 1] == Y[j - 1]: + L[i][j] = L[i - 1][j - 1] + 1 + else: + L[i][j] = max(L[i - 1][j], L[i][j - 1]) + + return L[m][n] + + # From https://github.com/sgl-project/sglang/blob/main/python/sglang/test/test_utils.py + def _calculate_rouge_l(self, output_strs_list1, output_strs_list2): + rouge_l_scores = [] + + for s1, s2 in zip(output_strs_list1, output_strs_list2): + lcs_len = self._lcs(s1, s2) + precision = lcs_len / len(s1) if len(s1) > 0 else 0 + recall = lcs_len / len(s2) if len(s2) > 0 else 0 + if precision + recall > 0: + fmeasure = (2 * precision * recall) / (precision + recall) + else: + fmeasure = 0.0 + rouge_l_scores.append(fmeasure) + + return rouge_l_scores + + def _benchmark_generation(self, model, inputs, n_warmup=3, n_runs=5): + for _ in range(n_warmup): + model.generate(**inputs, max_new_tokens=20, do_sample=False) + torch.cuda.synchronize() + + start_time = torch.cuda.Event(enable_timing=True) + end_time = torch.cuda.Event(enable_timing=True) + + start_time.record() + for _ in range(n_runs): + model.generate(**inputs, max_new_tokens=20, do_sample=False) + end_time.record() + torch.cuda.synchronize() + + return start_time.elapsed_time(end_time) / n_runs + + @pytest.mark.flash_attn_3_test + @require_torch_gpu + @require_flash_attn + @require_flash_attn_3 + @slow + def test_flash_attention_2_3_parity(self): + model_id = "meta-llama/Llama-3.2-1B-Instruct" + prompt = "The ETH AI Center is" + + # 1. Load FA2 model and tokenizer + model_2 = AutoModelForCausalLM.from_pretrained( + model_id, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", + ).to("cuda") + tokenizer = AutoTokenizer.from_pretrained(model_id) + + # 2. Load FA3 model + try: + model_3 = AutoModelForCausalLM.from_pretrained( + model_id, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_3", + ).to("cuda") + except (ValueError, ImportError) as e: + pytest.skip(f"Could not load Flash Attention 3 model, skipping test. Error: {e}") + + # 3. Generate with both models + inputs = tokenizer(prompt, return_tensors="pt").to("cuda") + + with torch.no_grad(): + output_2 = model_2.generate( + **inputs, max_new_tokens=20, do_sample=False, output_scores=True, return_dict_in_generate=True + ) + output_3 = model_3.generate( + **inputs, max_new_tokens=20, do_sample=False, output_scores=True, return_dict_in_generate=True + ) + + # 4. Correctness check + # 4a. 
Logits + logits_2 = torch.stack(output_2.scores) + logits_3 = torch.stack(output_3.scores) + torch.testing.assert_close(logits_2, logits_3, atol=1e-3, rtol=1e-3) + logprobs_2 = torch.nn.functional.log_softmax(logits_2, dim=-1) + logprobs_3 = torch.nn.functional.log_softmax(logits_3, dim=-1) + max_logprob_diff = torch.max(torch.abs(logprobs_2 - logprobs_3)).item() + + # 4b. Generated text + text_2 = tokenizer.decode(output_2.sequences[0], skip_special_tokens=True) + text_3 = tokenizer.decode(output_3.sequences[0], skip_special_tokens=True) + rouge_score = self._calculate_rouge_l([text_2], [text_3])[0] + assert rouge_score > 0.99, f"Generated texts do not match (ROUGE-L: {rouge_score})" + + # 5. Performance check + with torch.no_grad(): + time_2 = self._benchmark_generation(model_2, inputs) + time_3 = self._benchmark_generation(model_3, inputs) + + print(f"\n--- Flash Attention {2, 3} Parity Test on {model_id} ---") + print(f"Prompt: '{prompt}'") + print(f"Generated text with Flash Attention 2: {text_2}") + print(f"Generated text with Flash Attention 3: {text_3}") + print(f"ROUGE-L: {rouge_score}") + print(f"Max absolute difference in logprobs: {max_logprob_diff:.5e}") + print(f"Flash Attention 2 latency: {time_2:.2f} ms") + print(f"Flash Attention 3 latency: {time_3:.2f} ms") + print(f"Speed-up: {time_2 / time_3:.2f}x") + print("---") diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index e92d1e1ec77..840d2e66e75 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -34,6 +34,7 @@ from transformers.testing_utils import ( is_flaky, require_accelerate, require_flash_attn, + require_flash_attn_3, require_optimum_quanto, require_read_token, require_torch, @@ -2292,6 +2293,7 @@ class GenerationTesterMixin: support_flag = { "sdpa": "_supports_sdpa", "flash_attention_2": "_supports_flash_attn_2", + "flash_attention_3": "_supports_flash_attn_3", } for model_class in self.all_generative_model_classes: @@ -2369,6 +2371,14 @@ class GenerationTesterMixin: """Tests that generate has equivalent outputs with FA2 and eager attention implementations.""" self._test_attention_implementation("flash_attention_2") + @pytest.mark.flash_attn_3_test + @require_flash_attn_3 + @require_torch_gpu + @slow + def test_eager_matches_fa3_generate(self): + """Tests that generate has equivalent outputs with FA3 and eager attention implementations.""" + self._test_attention_implementation("flash_attention_3") + def _check_generate_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1): input_batch_size = int(output.sequences.shape[0] / num_return_sequences) internal_batch_size = ( diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index f7183089044..a5d9c900680 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -84,6 +84,7 @@ from transformers.testing_utils import ( require_bitsandbytes, require_deepspeed, require_flash_attn, + require_flash_attn_3, require_non_hpu, require_safetensors, require_torch, @@ -3129,18 +3130,19 @@ class ModelTesterMixin: f"{model_class} is too big for the common tests ({num_params})! It should have 1M max." ) - @require_flash_attn - @require_torch_gpu - @mark.flash_attn_test - @slow - @is_flaky() - def test_flash_attn_2_inference_equivalence(self): + def flash_attn_inference_equivalence(self, attn_implementation: str, padding_side: str): + r""" + Tests the equivalence between the eager and flash attention implementations. 
+ This test is only for inference and runs with `torch_dtype=torch.bfloat16`. + """ if not self.has_attentions: self.skipTest(reason="Model architecture does not support attentions") for model_class in self.all_model_classes: - if not model_class._supports_flash_attn_2: - self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") + if (attn_implementation == "flash_attention_2" and not model_class._supports_flash_attn_2) or ( + attn_implementation == "flash_attention_3" and not model_class._supports_flash_attn_3 + ): + self.skipTest(f"{model_class.__name__} does not support {attn_implementation}") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() model = model_class(config) @@ -3148,7 +3150,7 @@ class ModelTesterMixin: with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) model_fa = model_class.from_pretrained( - tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" + tmpdirname, torch_dtype=torch.bfloat16, attn_implementation=attn_implementation ) model_fa.to(torch_device) @@ -3163,9 +3165,12 @@ class ModelTesterMixin: if dummy_attention_mask is not None: dummy_attention_mask = dummy_attention_mask[:1] - dummy_attention_mask[:, 1:] = 1 - dummy_attention_mask[:, :1] = 0 - + if padding_side == "left": + dummy_attention_mask[:, 1:] = 1 + dummy_attention_mask[:, :1] = 0 + else: + dummy_attention_mask[:, :-1] = 1 + dummy_attention_mask[:, -1:] = 0 if model.config.is_encoder_decoder: decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:1] @@ -3220,11 +3225,22 @@ class ModelTesterMixin: else outputs_fa.decoder_hidden_states[-1] ) - assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2) + if padding_side == "left": + assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2) - # check with inference + dropout - model.train() - _ = model_fa(dummy_input, **other_inputs) + # check with inference + dropout + model.train() + _ = model_fa(dummy_input, **other_inputs) + else: + assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2) + + @require_flash_attn + @require_torch_gpu + @mark.flash_attn_test + @slow + @is_flaky() + def test_flash_attn_2_inference_equivalence(self): + self.flash_attn_inference_equivalence(attn_implementation="flash_attention_2", padding_side="left") @require_flash_attn @require_torch_gpu @@ -3232,92 +3248,23 @@ class ModelTesterMixin: @slow @is_flaky() def test_flash_attn_2_inference_equivalence_right_padding(self): - if not self.has_attentions: - self.skipTest(reason="Model architecture does not support attentions") + self.flash_attn_inference_equivalence(attn_implementation="flash_attention_2", padding_side="right") - for model_class in self.all_model_classes: - if not model_class._supports_flash_attn_2: - self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") + @require_flash_attn_3 + @require_torch_gpu + @mark.flash_attn_3_test + @slow + @is_flaky() + def test_flash_attn_3_inference_equivalence(self): + self.flash_attn_inference_equivalence(attn_implementation="flash_attention_3", padding_side="left") - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - model = model_class(config) - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - model_fa = model_class.from_pretrained( - tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" - ) - model_fa.to(torch_device) - - model = 
model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16) - model.to(torch_device) - - dummy_input = inputs_dict[model.main_input_name][:1] - if dummy_input.dtype in [torch.float32, torch.float16]: - dummy_input = dummy_input.to(torch.bfloat16) - - dummy_attention_mask = inputs_dict.get("attention_mask", None) - - if dummy_attention_mask is not None: - dummy_attention_mask = dummy_attention_mask[:1] - dummy_attention_mask[:, :-1] = 1 - dummy_attention_mask[:, -1:] = 0 - - if model.config.is_encoder_decoder: - decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:1] - - outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) - outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) - else: - outputs = model(dummy_input, output_hidden_states=True) - outputs_fa = model_fa(dummy_input, output_hidden_states=True) - - logits = ( - outputs.hidden_states[-1] - if not model.config.is_encoder_decoder - else outputs.decoder_hidden_states[-1] - ) - logits_fa = ( - outputs_fa.hidden_states[-1] - if not model.config.is_encoder_decoder - else outputs_fa.decoder_hidden_states[-1] - ) - - assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2) - - if model.config.is_encoder_decoder: - other_inputs = { - "decoder_input_ids": decoder_input_ids, - "decoder_attention_mask": dummy_attention_mask, - "output_hidden_states": True, - } - if dummy_attention_mask is not None: - other_inputs["attention_mask"] = dummy_attention_mask - - outputs = model(dummy_input, **other_inputs) - outputs_fa = model_fa(dummy_input, **other_inputs) - else: - other_inputs = { - "output_hidden_states": True, - } - if dummy_attention_mask is not None: - other_inputs["attention_mask"] = dummy_attention_mask - - outputs = model(dummy_input, **other_inputs) - outputs_fa = model_fa(dummy_input, **other_inputs) - - logits = ( - outputs.hidden_states[-1] - if not model.config.is_encoder_decoder - else outputs.decoder_hidden_states[-1] - ) - logits_fa = ( - outputs_fa.hidden_states[-1] - if not model.config.is_encoder_decoder - else outputs_fa.decoder_hidden_states[-1] - ) - - assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2) + @require_flash_attn_3 + @require_torch_gpu + @mark.flash_attn_3_test + @slow + @is_flaky() + def test_flash_attn_3_inference_equivalence_right_padding(self): + self.flash_attn_inference_equivalence(attn_implementation="flash_attention_3", padding_side="right") def test_attn_implementation_composite_models(self): """ @@ -3959,24 +3906,21 @@ class ModelTesterMixin: torch.allclose(res_eager[attention_mask == 1], res_sdpa[attention_mask == 1], rtol=1e-4, atol=1e-4) ) - @require_flash_attn - @require_torch_gpu - @mark.flash_attn_test - def test_flash_attn_2_can_dispatch_composite_models(self): + def flash_attn_can_dispatch_composite_models(self, attn_implementation: str): """ - Tests if composite models can dispatch on FA2 if the sub-models support FA2. + Tests if composite models can dispatch on flash attention if the sub-models support it. The tests is needed as we handle differently composite models and we cannot check them - with above tests. If any of the sub-models does not support FA2, we'll raise an error when dispatching + with above tests. If any of the sub-models does not support flash attention, we'll raise an error when dispatching that particular sub-model. 
Otherwise we dispatch safely in all sub-models, where "sub-models" are specific backbone models (LM/vision/audio/etc) """ if not self.has_attentions: self.skipTest(reason="Model architecture does not support attentions") - if not is_torch_fp16_available_on_device(torch_device): - self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)") + if not is_torch_bf16_available_on_device(torch_device): + self.skipTest(f"bfloat16 not supported on {torch_device} (on the specific device currently used)") - torch_dtype = torch.float16 + torch_dtype = torch.bfloat16 for model_class in self.all_model_classes: config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() model = model_class(config) @@ -3987,44 +3931,64 @@ class ModelTesterMixin: model.save_pretrained(tmpdirname) model = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype) - sub_models_supporting_fa2 = [ - module._supports_flash_attn_2 + sub_models_supporting_fa = [ + ( + module._supports_flash_attn_3 + if attn_implementation == "flash_attention_3" + else module._supports_flash_attn_2 + ) for name, module in model.named_modules() if isinstance(module, PreTrainedModel) and name != "" ] - supports_fa2_all_modules = ( - all(sub_models_supporting_fa2) - if len(sub_models_supporting_fa2) > 0 - else model._supports_flash_attn_2 + supports_fa_all_modules = ( + all(sub_models_supporting_fa) + if len(sub_models_supporting_fa) > 0 + else ( + model._supports_flash_attn_3 + if attn_implementation == "flash_attention_3" + else model._supports_flash_attn_2 + ) ) - if not supports_fa2_all_modules: + if not supports_fa_all_modules: with self.assertRaises(ValueError): - model_fa2 = model_class.from_pretrained( + model_fa = model_class.from_pretrained( tmpdirname, torch_dtype=torch_dtype, - attn_implementation="flash_attention_2", + attn_implementation=attn_implementation, ) else: - model_fa2 = model_class.from_pretrained( - tmpdirname, torch_dtype=torch_dtype, attn_implementation="flash_attention_2" + model_fa = model_class.from_pretrained( + tmpdirname, torch_dtype=torch_dtype, attn_implementation=attn_implementation ) - for key in model_fa2.config: - if isinstance(getattr(model_fa2.config, key), PretrainedConfig): - sub_config = getattr(model_fa2.config, key) - self.assertTrue(sub_config._attn_implementation == "flash_attention_2") + for key in model_fa.config: + if isinstance(getattr(model_fa.config, key), PretrainedConfig): + sub_config = getattr(model_fa.config, key) + self.assertTrue(sub_config._attn_implementation == attn_implementation) - has_fa2 = False - for name, submodule in model_fa2.named_modules(): + has_fa = False + for name, submodule in model_fa.named_modules(): class_name = submodule.__class__.__name__ if ( "Attention" in class_name and getattr(submodule, "config", None) - and submodule.config._attn_implementation == "flash_attention_2" + and submodule.config._attn_implementation == attn_implementation ): - has_fa2 = True + has_fa = True break - if not has_fa2: - raise ValueError("The FA2 model should have FA2 layers") + if not has_fa: + raise ValueError(f"The {attn_implementation} model should have {attn_implementation} layers") + + @require_flash_attn + @require_torch_gpu + @mark.flash_attn_test + def test_flash_attn_2_can_dispatch_composite_models(self): + self.flash_attn_can_dispatch_composite_models(attn_implementation="flash_attention_2") + + @require_flash_attn_3 + @require_torch_gpu + @mark.flash_attn_3_test + def 
test_flash_attn_3_can_dispatch_composite_models(self): + self.flash_attn_can_dispatch_composite_models(attn_implementation="flash_attention_3") @require_flash_attn @require_torch_gpu @@ -4121,27 +4085,29 @@ class ModelTesterMixin: assert not loss.isnan().any() - @require_flash_attn - @require_torch_gpu - @mark.flash_attn_test - @slow - def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self): + def flash_attention_padding_matches_padding_free_with_position_ids( + self, attn_implementation: str, fa_kwargs: bool = False + ): if not self.has_attentions: self.skipTest(reason="Model architecture does not support attentions") max_new_tokens = 30 for model_class in self.all_generative_model_classes: - if not model_class._supports_flash_attn_2: - self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") + if not ( + model_class._supports_flash_attn_2 + if attn_implementation == "flash_attention_2" + else model_class._supports_flash_attn_3 + ): + self.skipTest(f"{model_class.__name__} does not support {attn_implementation}") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() if 0 not in inputs_dict.get("attention_mask", []) or "attention_mask" not in inputs_dict: self.skipTest("Model dummy inputs should contain padding in their attention mask") dummy_input = inputs_dict[model_class.main_input_name] - if dummy_input.dtype in [torch.float32, torch.bfloat16]: - dummy_input = dummy_input.to(torch.float16) + if dummy_input.dtype in [torch.float32, torch.float16]: + dummy_input = dummy_input.to(torch.bfloat16) # make sure that all models have enough positions for generation if hasattr(config, "max_position_embeddings"): @@ -4151,7 +4117,7 @@ class ModelTesterMixin: if "position_ids" not in inspect.signature(model.forward).parameters: self.skipTest("Model does not support position_ids") - if "position_ids" not in inspect.signature(model.forward).parameters: + if (not fa_kwargs) and "position_ids" not in inspect.signature(model.forward).parameters: continue # this model doesn't accept position ids as input with tempfile.TemporaryDirectory() as tmpdirname: @@ -4166,26 +4132,40 @@ class ModelTesterMixin: model = ( model_class.from_pretrained( tmpdirname, - torch_dtype=torch.float16, - attn_implementation="flash_attention_2", + torch_dtype=torch.bfloat16, + attn_implementation=attn_implementation, ) .to(torch_device) .eval() ) - # flatten - padfree_inputs_dict = { - k: v[dummy_attention_mask.bool()].unsqueeze(0) - for k, v in inputs_dict.items() - if not k == "attention_mask" - } - # add position_ids - padfree_inputs_dict["position_ids"] = ( - torch.cat([torch.arange(length) for length in dummy_attention_mask.sum(1).tolist()]) - .long() - .unsqueeze(0) - .to(torch_device) - ) + if fa_kwargs: + # flatten + features = [ + {"input_ids": i[a.bool()].tolist()} + for i, a in zip(inputs_dict["input_ids"], inputs_dict["attention_mask"]) + ] + + # add position_ids + fa_kwargs + data_collator = DataCollatorWithFlattening(return_tensors="pt", return_flash_attn_kwargs=True) + batch = data_collator(features) + padfree_inputs_dict = { + k: t.to(torch_device) if torch.is_tensor(t) else t for k, t in batch.items() + } + else: + # flatten + padfree_inputs_dict = { + k: v[dummy_attention_mask.bool()].unsqueeze(0) + for k, v in inputs_dict.items() + if not k == "attention_mask" + } + # add position_ids + padfree_inputs_dict["position_ids"] = ( + torch.cat([torch.arange(length) for length in dummy_attention_mask.sum(1).tolist()]) + .long() + .unsqueeze(0) + 
.to(torch_device) + ) res_padded = model(**inputs_dict) res_padfree = model(**padfree_inputs_dict) @@ -4195,119 +4175,96 @@ class ModelTesterMixin: torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), rtol=0, atol=0) # acceptable numerical instability - tol = torch.finfo(torch.float16).eps + tol = torch.finfo(torch.bfloat16).eps torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol) + @require_flash_attn + @require_torch_gpu + @mark.flash_attn_test + @slow + def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self): + self.flash_attention_padding_matches_padding_free_with_position_ids(attn_implementation="flash_attention_2") + @require_flash_attn @require_torch_gpu @mark.flash_attn_test @slow def test_flash_attention_2_padding_matches_padding_free_with_position_ids_and_fa_kwargs(self): - if not self.has_attentions: - self.skipTest(reason="Model architecture does not support attentions") + self.flash_attention_padding_matches_padding_free_with_position_ids( + attn_implementation="flash_attention_2", fa_kwargs=True + ) - max_new_tokens = 30 - - for model_class in self.all_generative_model_classes: - if not model_class._supports_flash_attn_2: - self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") - - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - if 0 not in inputs_dict.get("attention_mask", []) or "attention_mask" not in inputs_dict: - self.skipTest("Model dummy inputs should contain padding in their attention mask") - - dummy_input = inputs_dict[model_class.main_input_name] - if dummy_input.dtype in [torch.float32, torch.bfloat16]: - dummy_input = dummy_input.to(torch.float16) - - # make sure that all models have enough positions for generation - if hasattr(config, "max_position_embeddings"): - config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1 - - model = model_class(config) - if "position_ids" not in inspect.signature(model.forward).parameters: - self.skipTest("Model does not support position_ids") - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - - # ensure left padding, to adapt for some models - if 0 in inputs_dict["attention_mask"][:, -1]: - inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1) - dummy_attention_mask = inputs_dict["attention_mask"] - inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.get_text_config().pad_token_id - - model = ( - model_class.from_pretrained( - tmpdirname, - torch_dtype=torch.float16, - attn_implementation="flash_attention_2", - ) - .to(torch_device) - .eval() - ) - - # flatten - features = [ - {"input_ids": i[a.bool()].tolist()} - for i, a in zip(inputs_dict["input_ids"], inputs_dict["attention_mask"]) - ] - - # add position_ids + fa_kwargs - data_collator = DataCollatorWithFlattening(return_tensors="pt", return_flash_attn_kwargs=True) - batch = data_collator(features) - batch_accelerator = {k: t.to(torch_device) if torch.is_tensor(t) else t for k, t in batch.items()} - - res_padded = model(**inputs_dict) - res_padfree = model(**batch_accelerator) - - logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()] - logits_padfree = res_padfree.logits[0] - - torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), rtol=0, atol=0) - # acceptable numerical instability - tol = torch.finfo(torch.float16).eps - torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol) - - 
@require_flash_attn
+    @require_flash_attn_3
     @require_torch_gpu
-    @mark.flash_attn_test
+    @mark.flash_attn_3_test
     @slow
-    def test_flash_attn_2_from_config(self):
+    def test_flash_attention_3_padding_matches_padding_free_with_position_ids(self):
+        self.flash_attention_padding_matches_padding_free_with_position_ids(attn_implementation="flash_attention_3")
+
+    @require_flash_attn_3
+    @require_torch_gpu
+    @mark.flash_attn_3_test
+    @slow
+    def test_flash_attention_3_padding_matches_padding_free_with_position_ids_and_fa_kwargs(self):
+        self.flash_attention_padding_matches_padding_free_with_position_ids(
+            attn_implementation="flash_attention_3", fa_kwargs=True
+        )
+
+    def flash_attn_from_config(self, attn_implementation: str):
+        r"""
+        Tests that the model can be instantiated with `attn_implementation` directly from the config, and that the
+        requested implementation is not persisted in the config when the model is saved and reloaded.
+        """
         if not self.has_attentions:
             self.skipTest(reason="Model architecture does not support attentions")
 
         for model_class in self.all_generative_model_classes:
-            if not model_class._supports_flash_attn_2:
-                self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
+            if (attn_implementation == "flash_attention_2" and not model_class._supports_flash_attn_2) or (
+                attn_implementation == "flash_attention_3" and not model_class._supports_flash_attn_3
+            ):
+                self.skipTest(f"{model_class.__name__} does not support {attn_implementation}")
 
             config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
             # TODO: to change it in the future with other relevant auto classes
-            fa2_model = model_class._from_config(
-                config, attn_implementation="flash_attention_2", torch_dtype=torch.float16
+            fa_model = model_class._from_config(
+                config, attn_implementation=attn_implementation, torch_dtype=torch.bfloat16
             ).to(torch_device)
 
-            dummy_input = inputs_dict[fa2_model.main_input_name]
-            if dummy_input.dtype in [torch.float32, torch.bfloat16]:
-                dummy_input = dummy_input.to(torch.float16)
+            dummy_input = inputs_dict[fa_model.main_input_name]
+            if dummy_input.dtype in [torch.float32, torch.float16]:
+                dummy_input = dummy_input.to(torch.bfloat16)
 
             dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
 
-            if fa2_model.config.is_encoder_decoder:
+            if fa_model.config.is_encoder_decoder:
                 dummy_decoder_input_ids = inputs_dict["decoder_input_ids"]
                 dummy_decoder_attention_mask = inputs_dict["decoder_attention_mask"]
 
-                _ = fa2_model(
+                _ = fa_model(
                     dummy_input,
                     attention_mask=dummy_attention_mask,
                     decoder_input_ids=dummy_decoder_input_ids,
                     decoder_attention_mask=dummy_decoder_attention_mask,
                 )
             else:
-                _ = fa2_model(dummy_input, attention_mask=dummy_attention_mask)
+                _ = fa_model(dummy_input, attention_mask=dummy_attention_mask)
 
             with tempfile.TemporaryDirectory() as tmpdirname:
-                fa2_model.save_pretrained(tmpdirname)
+                fa_model.save_pretrained(tmpdirname)
                 model_from_pretrained = model_class.from_pretrained(tmpdirname)
-                self.assertTrue(model_from_pretrained.config._attn_implementation != "flash_attention_2")
+                self.assertTrue(model_from_pretrained.config._attn_implementation != attn_implementation)
+
+    @require_flash_attn
+    @require_torch_gpu
+    @mark.flash_attn_test
+    @slow
+    def test_flash_attn_2_from_config(self):
+        self.flash_attn_from_config(attn_implementation="flash_attention_2")
+
+    @require_flash_attn_3
+    @require_torch_gpu
+    @mark.flash_attn_3_test
+    @slow
+    def test_flash_attn_3_from_config(self):
+        self.flash_attn_from_config(attn_implementation="flash_attention_3")
 
     def 
_get_custom_4d_mask_test_data(self): # Sequence in which all but the last token is the same diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index 903283dd4a9..7df23e02959 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -77,6 +77,7 @@ from transformers.utils import ( ) from transformers.utils.import_utils import ( is_flash_attn_2_available, + is_flash_attn_3_available, is_flax_available, is_tf_available, is_torch_npu_available, @@ -676,6 +677,9 @@ class ModelUtilsTest(TestCasePlus): if is_flash_attn_available(): attn_implementation_available.append("flash_attention_2") + if is_flash_attn_3_available(): + attn_implementation_available.append("flash_attention_3") + for requested_attn_implementation in attn_implementation_available: model = AutoModelForCausalLM.from_pretrained( TINY_MISTRAL, attn_implementation=requested_attn_implementation @@ -700,6 +704,9 @@ class ModelUtilsTest(TestCasePlus): if is_flash_attn_available(): attn_implementation_available.append("flash_attention_2") + if is_flash_attn_3_available(): + attn_implementation_available.append("flash_attention_3") + for requested_attn_implementation in attn_implementation_available: config = AutoConfig.from_pretrained(TINY_MISTRAL, attn_implementation=requested_attn_implementation) # Ensure the config was set correctly
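
For reference, opting into the new backend from user code mirrors the existing Flash Attention 2 path via the `attn_implementation` argument. A minimal sketch follows; it assumes a CUDA GPU with a working flash_attn_3 build, and the checkpoint id is purely illustrative:

    # Sketch only: pick FA3 when available, otherwise fall back to FA2.
    # "your-org/your-causal-lm" is a placeholder; any causal LM whose model class
    # sets _supports_flash_attn_3 is expected to behave the same way.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from transformers.utils.import_utils import is_flash_attn_3_available

    attn_implementation = "flash_attention_3" if is_flash_attn_3_available() else "flash_attention_2"

    tokenizer = AutoTokenizer.from_pretrained("your-org/your-causal-lm")
    model = AutoModelForCausalLM.from_pretrained(
        "your-org/your-causal-lm",
        torch_dtype=torch.bfloat16,
        attn_implementation=attn_implementation,
    ).to("cuda")

    inputs = tokenizer("The quick brown fox", return_tensors="pt").to("cuda")
    output = model.generate(**inputs, max_new_tokens=20)
    print(tokenizer.decode(output[0], skip_special_tokens=True))

The new tests follow the same selection pattern as the existing suite: the `flash_attn_3_test` marker selects or deselects them, alongside the `flash_attn_test` marker used for the Flash Attention 2 tests.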