diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py index 7e413cd0640..b5fd8a1e9df 100755 --- a/src/transformers/integrations/__init__.py +++ b/src/transformers/integrations/__init__.py @@ -139,6 +139,15 @@ else: "SUPPORTED_TP_STYLES", "translate_to_torch_parallel_style", ] +try: + if not is_torch_greater_or_equal("2.5"): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["flex_attention"] = [ + "make_flex_block_causal_mask", + ] if TYPE_CHECKING: from .aqlm import replace_with_aqlm_linear @@ -255,6 +264,13 @@ if TYPE_CHECKING: translate_to_torch_parallel_style, ) + try: + if not is_torch_greater_or_equal("2.5"): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .flex_attention import make_flex_block_causal_mask else: import sys diff --git a/src/transformers/integrations/flex_attention.py b/src/transformers/integrations/flex_attention.py index 5181b2c1a0a..aff1eb93af7 100644 --- a/src/transformers/integrations/flex_attention.py +++ b/src/transformers/integrations/flex_attention.py @@ -1,4 +1,32 @@ -from typing import Optional, Tuple +""" +Partially inspired by torchtune's flex attention implementation + +Citation: +@software{torchtune, + title = {torchtune: PyTorch's finetuning library}, + author = {torchtune maintainers and contributors}, + url = {https://github.com/pytorch/torchtune}, + license = {BSD-3-Clause}, + month = apr, + year = {2024} +} +""" +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple, Union import torch @@ -6,7 +34,114 @@ from ..utils import is_torch_flex_attn_available if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import flex_attention + from torch.nn.attention.flex_attention import ( + BlockMask, + flex_attention, + ) + from torch.nn.attention.flex_attention import ( + create_block_mask as create_block_causal_mask_flex, + ) + + +class WrappedFlexAttention: + """ + We use a singleton class so that flex attention is compiled only once, the first time it is called. + """ + + _instance = None + _is_flex_compiled = False + _compiled_flex_attention = None + + def __new__(cls, *args, **kwargs): + if cls._instance is None: + # Create a new instance if one doesn't already exist + cls._instance = super().__new__(cls) + return cls._instance + + @torch.compiler.disable(recursive=False) + def __init__(self): + """ + Initialize or update the singleton instance. + """ + if self._is_flex_compiled is False: + self._compiled_flex_attention = torch.compile(flex_attention, dynamic=False) + self._is_flex_compiled = True + + def __call__(self): + return self._compiled_flex_attention + + +def make_flex_block_causal_mask(attention_mask_2d: torch.Tensor) -> BlockMask: + """ + Create a block causal document mask for a batch of sequences, both packed and unpacked.
+ Creates the block causal logic and passes it into :func:`torch.nn.attention.flex_attention.create_block_mask`. + The resultant BlockMask is a compressed representation of the full block causal + mask. BlockMask is essential for performant computation of flex attention. + See: https://pytorch.org/blog/flexattention/ + + Args: + attention_mask_2d (torch.Tensor): Attention mask for packed and padded sequences + of shape (batch_size, total_seq_len). e.g. + + For unpacked sequence: + [[1, 1, 1, 1, 0, 0, 0], + [1, 1, 1, 1, 1, 0, 0]] + + For packed sequence: + [[1, 1, 1, 2, 2, 2, 0], + [1, 1, 2, 2, 2, 3, 3]] + + Returns: + BlockMask + """ + device = attention_mask_2d.device + + document_ids = attention_mask_2d + batch_size, total_seq_len = document_ids.shape + + # Instead of passing a tensor mask, flex attention requires a mask_mod function + # that determines which elements of QK^T should be included in the attention + # computation prior to the softmax. For sample packing, we need the + # logic for both the causal mask and the document mask. See PyTorch's official + # blog post for more details: https://pytorch.org/blog/flexattention/#mask-mods + def causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx): + """ + Defines the logic of a block causal mask by combining a standard causal mask + with a block diagonal document mask. + + See :func:`~torchtune.modules.attention_utils.create_block_causal_mask` + for an illustration. + """ + causal_mask = q_idx >= kv_idx + document_mask = document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx] + padding_mask = document_ids[batch_idx, q_idx] > 0 + return causal_mask & document_mask & padding_mask + + return create_block_causal_mask_flex( + mask_mod=causal_mask_mod, + B=batch_size, + H=None, # attention head + Q_LEN=total_seq_len, + KV_LEN=total_seq_len, + device=device, + ) + + +@torch.compiler.disable(recursive=False) +def compile_friendly_flex_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + **kwargs, +) -> torch.Tensor: + # The first call initializes the singleton wrapper object; the second call invokes the wrapper to return the compiled flex attention + flex_attention_compiled = WrappedFlexAttention()() + return flex_attention_compiled( + query, + key, + value, + **kwargs, + ) def flex_attention_forward( @@ -14,30 +149,37 @@ def flex_attention_forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attention_mask: Optional[torch.Tensor], + attention_mask: Union[torch.Tensor, BlockMask], scaling: Optional[float] = None, softcap: Optional[float] = None, head_mask: Optional[torch.Tensor] = None, **kwargs, ) -> Tuple[torch.Tensor, torch.Tensor]: - causal_mask = attention_mask + block_mask = None + causal_mask = None + if isinstance(attention_mask, BlockMask): + block_mask = attention_mask + else: + causal_mask = attention_mask + if causal_mask is not None: causal_mask = causal_mask[:, :, :, : key.shape[-2]] - def causal_mod(score, b, h, q_idx, kv_idx): + def score_mod(score, batch_idx, head_idx, q_idx, kv_idx): if softcap is not None: score = softcap * torch.tanh(score / softcap) if causal_mask is not None: - score = score + causal_mask[b][0][q_idx][kv_idx] + score = score + causal_mask[batch_idx][0][q_idx][kv_idx] if head_mask is not None: - score = score + head_mask[b][h][0][0] + score = score + head_mask[batch_idx][head_idx][0][0] return score - attn_output, attention_weights = flex_attention( + attn_output, attention_weights = compile_friendly_flex_attention( query, key, value, - score_mod=causal_mod, +
score_mod=score_mod, + block_mask=block_mask, enable_gqa=True, scale=scaling, # Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless. diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index 202f4f8ad52..05d5ba89979 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -34,6 +34,7 @@ from ...utils import ( LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, + is_torch_flex_attn_available, logging, replace_return_docstrings, ) @@ -48,6 +49,12 @@ if is_torch_available(): from torch import nn +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "AriaTextConfig" @@ -1014,6 +1021,11 @@ class AriaTextModel(AriaTextPreTrainedModel): if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + if isinstance(attention_mask, BlockMask): + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index f14dc879b7a..504220c788b 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -36,10 +36,20 @@ from ...modeling_outputs import ( TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...utils import is_torchdynamo_compiling, logging +from ...utils import ( + is_torch_flex_attn_available, + is_torchdynamo_compiling, + logging, +) from .configuration_bloom import BloomConfig +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "bigscience/bloom-560m" @@ -743,6 +753,11 @@ class BloomModel(BloomPreTrainedModel): if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + if isinstance(attention_mask, BlockMask): + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index 7510782e5ed..3c3bcb4d011 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -41,6 +41,7 @@ from ...utils import ( add_start_docstrings_to_model_forward, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, + is_torch_flex_attn_available, is_torchdynamo_compiling, logging, replace_return_docstrings, @@ -48,6 +49,12 @@ from ...utils import ( from .configuration_chameleon import ChameleonConfig, ChameleonVQVAEConfig +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + if is_flash_attn_2_available(): from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa @@ -1389,6 +1396,11 @@ class ChameleonModel(ChameleonPreTrainedModel): if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + if isinstance(attention_mask, BlockMask): + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py index 11628e8ee61..89e46a5523d 100644 --- a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -26,10 +26,22 @@ from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_utils import PreTrainedModel -from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_torch_flex_attn_available, + logging, +) from .configuration_codegen import CodeGenConfig +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "Salesforce/codegen-2B-mono" @@ -586,6 +598,11 @@ class CodeGenModel(CodeGenPreTrainedModel): if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + if isinstance(attention_mask, BlockMask): + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 0522b5ec40c..018fc3b8622 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -45,6 +45,7 @@ from ...utils import ( LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, + is_torch_flex_attn_available, logging, replace_return_docstrings, ) @@ -52,6 +53,12 @@ from ...utils.deprecation import deprecate_kwarg from .configuration_cohere import CohereConfig +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "CohereConfig" @@ -664,6 +671,11 @@ class CohereModel(CoherePreTrainedModel): if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + if isinstance(attention_mask, BlockMask): + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index 5e8d81415da..44665289488 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -32,6 +32,7 @@ from ...utils import ( add_start_docstrings_to_model_forward, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, + is_torch_flex_attn_available, logging, replace_return_docstrings, ) @@ -39,6 +40,12 @@ from ...utils.deprecation import deprecate_kwarg from .configuration_dbrx import DbrxConfig +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward @@ -1118,6 +1125,11 @@ class DbrxModel(DbrxPreTrainedModel): if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + if isinstance(attention_mask, BlockMask): + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py index 6edf83f8a22..f490525d999 100644 --- a/src/transformers/models/diffllama/modeling_diffllama.py +++ b/src/transformers/models/diffllama/modeling_diffllama.py @@ -48,6 +48,7 @@ from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_greater_or_equal_2_10, + is_torch_flex_attn_available, logging, replace_return_docstrings, ) @@ -55,6 +56,12 @@ from ...utils.deprecation import deprecate_kwarg from .configuration_diffllama import DiffLlamaConfig +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "kajuma/DiffLlama-0.3B-handcut" @@ -903,6 +910,11 @@ class DiffLlamaModel(DiffLlamaPreTrainedModel): if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + if isinstance(attention_mask, BlockMask): + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py index 43b1d8f93f5..8ad29c02ad5 100644 --- a/src/transformers/models/emu3/modeling_emu3.py +++ b/src/transformers/models/emu3/modeling_emu3.py @@ -41,6 +41,7 @@ from ...utils import ( LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, + is_torch_flex_attn_available, logging, replace_return_docstrings, ) @@ -48,6 +49,12 @@ from ...utils.deprecation import deprecate_kwarg from .configuration_emu3 import Emu3Config, Emu3TextConfig, Emu3VQVAEConfig +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) @@ -1482,6 +1489,11 @@ class Emu3TextModel(Emu3PreTrainedModel): if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + if isinstance(attention_mask, BlockMask): + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index ddef6a8ca57..08305274725 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -43,6 +43,7 @@ from ...utils import ( add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + is_torch_flex_attn_available, logging, replace_return_docstrings, ) @@ -50,6 +51,12 @@ from ...utils.deprecation import deprecate_kwarg from .configuration_gemma import GemmaConfig +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "google/gemma-7b" @@ -636,6 +643,11 @@ class GemmaModel(GemmaPreTrainedModel): if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + if isinstance(attention_mask, BlockMask): + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index a162a21c0bb..53f488116da 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -43,6 +43,7 @@ from ...utils import ( add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + is_torch_flex_attn_available, logging, replace_return_docstrings, ) @@ -50,6 +51,12 @@ from ...utils.deprecation import deprecate_kwarg from .configuration_glm import GlmConfig +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "THUDM/glm-4-9b" @@ -645,6 +652,11 @@ class GlmModel(GlmPreTrainedModel): if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + if isinstance(attention_mask, BlockMask): + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index bc48252578d..e72fff18e00 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -42,12 +42,19 @@ from ...utils import ( add_start_docstrings_to_model_forward, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, + is_torch_flex_attn_available, is_torch_fx_available, logging, ) from .configuration_gpt_neo import GPTNeoConfig +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward @@ -795,6 +802,11 @@ class GPTNeoModel(GPTNeoPreTrainedModel): if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + if isinstance(attention_mask, BlockMask): + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 8ff53763033..9bbd94d798a 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -29,12 +29,19 @@ from ...utils import ( add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + is_torch_flex_attn_available, logging, replace_return_docstrings, ) from .configuration_gpt_neox import GPTNeoXConfig +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) @@ -639,6 +646,11 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel): if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + if isinstance(attention_mask, BlockMask): + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index 7ea246f8519..3e13ce09c59 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -29,10 +29,19 @@ from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import PreTrainedModel -from ...utils import logging +from ...utils import ( + is_torch_flex_attn_available, + logging, +) from .configuration_gpt_neox_japanese import GPTNeoXJapaneseConfig +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "abeja/gpt-neox-japanese-2.7b" @@ -665,6 +674,11 @@ class GPTNeoXJapaneseModel(GPTNeoXJapanesePreTrainedModel): if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + if isinstance(attention_mask, BlockMask): + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index 9f5413e4b46..24f224ad1bb 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -40,6 +40,7 @@ from ...utils import ( add_start_docstrings_to_model_forward, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, + is_torch_flex_attn_available, is_torch_fx_proxy, logging, ) @@ -47,6 +48,12 @@ from ...utils.model_parallel_utils import assert_device_map, get_device_map from .configuration_gptj import GPTJConfig +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward @@ -894,6 +901,11 @@ class GPTJModel(GPTJPreTrainedModel): if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + if isinstance(attention_mask, BlockMask): + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index 3a239450657..1822bd627d1 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -37,6 +37,7 @@ from ...utils import ( LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, + is_torch_flex_attn_available, logging, replace_return_docstrings, ) @@ -44,6 +45,12 @@ from ...utils.deprecation import deprecate_kwarg from .configuration_granite import GraniteConfig +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "GraniteConfig" @@ -648,6 +655,11 @@ class GraniteModel(GranitePreTrainedModel): if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + if isinstance(attention_mask, BlockMask): + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index 92315e6bd68..73086f958c7 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -37,12 +37,19 @@ from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_greater_or_equal_2_10, + is_torch_flex_attn_available, logging, replace_return_docstrings, ) from .configuration_granitemoe import GraniteMoeConfig +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "GraniteMoeConfig" @@ -1121,6 +1128,11 @@ class GraniteMoeModel(GraniteMoePreTrainedModel): if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + if isinstance(attention_mask, BlockMask): + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py index 2c0c20e6ddd..3dc86991c73 100644 --- a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py +++ b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py @@ -37,12 +37,19 @@ from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_greater_or_equal_2_10, + is_torch_flex_attn_available, logging, replace_return_docstrings, ) from .configuration_granitemoeshared import GraniteMoeSharedConfig +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) @@ -1066,6 +1073,11 @@ class GraniteMoeSharedModel(GraniteMoeSharedPreTrainedModel): if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + if isinstance(attention_mask, BlockMask): + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/helium/modeling_helium.py b/src/transformers/models/helium/modeling_helium.py index 6793bbf201a..c3f57149d80 100644 --- a/src/transformers/models/helium/modeling_helium.py +++ b/src/transformers/models/helium/modeling_helium.py @@ -44,6 +44,7 @@ from ...utils import ( add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + is_torch_flex_attn_available, logging, replace_return_docstrings, ) @@ -51,6 +52,12 @@ from ...utils.deprecation import deprecate_kwarg from .configuration_helium import HeliumConfig +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "google/helium-7b" @@ -632,6 +639,11 @@ class HeliumModel(HeliumPreTrainedModel): if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + if isinstance(attention_mask, BlockMask): + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index 6b872f421f9..3ca196936c3 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -38,6 +38,7 @@ from ...pytorch_utils import ALL_LAYERNORM_LAYERS from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, + is_torch_flex_attn_available, logging, replace_return_docstrings, ) @@ -46,6 +47,12 @@ from .perceiver import IdeficsPerceiverResampler from .vision import IdeficsVisionTransformer +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "IdeficsConfig" @@ -1366,6 +1373,11 @@ class IdeficsModel(IdeficsPreTrainedModel): if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + if isinstance(attention_mask, BlockMask): + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index aa9bf4fa68c..814c62a5fd9 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -38,6 +38,7 @@ from ...utils import ( add_start_docstrings_to_model_forward, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, + is_torch_flex_attn_available, logging, replace_return_docstrings, ) @@ -45,6 +46,12 @@ from ...utils.deprecation import deprecate_kwarg from .configuration_jetmoe import JetMoeConfig +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward @@ -1127,6 +1134,11 @@ class JetMoeModel(JetMoePreTrainedModel): if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + if isinstance(attention_mask, BlockMask): + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index d2bf73e8095..159f41b3cec 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -44,6 +44,7 @@ from ...utils import ( add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + is_torch_flex_attn_available, logging, replace_return_docstrings, ) @@ -51,6 +52,12 @@ from ...utils.deprecation import deprecate_kwarg from .configuration_llama import LlamaConfig +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "meta-llama/Llama-2-7b-hf" @@ -634,6 +641,11 @@ class LlamaModel(LlamaPreTrainedModel): if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + if isinstance(attention_mask, BlockMask): + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index ab5f1b72ab0..9dce3162360 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -40,6 +40,7 @@ from ...utils import ( DUMMY_MASK, add_start_docstrings, add_start_docstrings_to_model_forward, + is_torch_flex_attn_available, is_torch_fx_proxy, is_torchdynamo_compiling, logging, @@ -48,6 +49,12 @@ from ...utils import ( from .configuration_longt5 import LongT5Config +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "LongT5Config" @@ -1603,6 +1610,11 @@ class LongT5Stack(LongT5PreTrainedModel): if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + if isinstance(attention_mask, BlockMask): + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py index 2ec4cca6ebc..a86d00ed8be 100644 --- a/src/transformers/models/mllama/modeling_mllama.py +++ b/src/transformers/models/mllama/modeling_mllama.py @@ -32,6 +32,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, + is_torch_flex_attn_available, is_torchdynamo_compiling, logging, replace_return_docstrings, @@ -40,6 +41,12 @@ from ...utils.deprecation import deprecate_kwarg from .configuration_mllama import MllamaConfig, MllamaTextConfig, MllamaVisionConfig +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) @@ -1081,6 +1088,11 @@ class MllamaPreTrainedModel(PreTrainedModel): if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + if isinstance(attention_mask, BlockMask): + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/moonshine/modeling_moonshine.py b/src/transformers/models/moonshine/modeling_moonshine.py index aac5f40de44..963b7e7aa25 100644 --- a/src/transformers/models/moonshine/modeling_moonshine.py +++ b/src/transformers/models/moonshine/modeling_moonshine.py @@ -46,12 +46,19 @@ from ...processing_utils import Unpack from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, + is_torch_flex_attn_available, logging, replace_return_docstrings, ) from .configuration_moonshine import MoonshineConfig +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "MoonshineConfig" @@ -998,6 +1005,11 @@ class MoonshineDecoder(MoonshinePreTrainedModel): if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + if isinstance(attention_mask, BlockMask): + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index 0266292388a..9c2d23b7afb 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -44,6 +44,7 @@ from ...utils import ( DUMMY_MASK, add_start_docstrings, add_start_docstrings_to_model_forward, + is_torch_flex_attn_available, is_torch_fx_proxy, is_torchdynamo_compiling, logging, @@ -53,6 +54,11 @@ from ...utils.model_parallel_utils import assert_device_map, get_device_map from .configuration_mt5 import MT5Config +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "MT5Config" @@ -1195,6 +1201,11 @@ class MT5Stack(MT5PreTrainedModel): if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + if isinstance(attention_mask, BlockMask): + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index a2c4805fafb..ef73c55e0fd 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -43,6 +43,7 @@ from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_greater_or_equal_2_10, + is_torch_flex_attn_available, logging, replace_return_docstrings, ) @@ -50,6 +51,12 @@ from ...utils.deprecation import deprecate_kwarg from .configuration_nemotron import NemotronConfig +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "nvidia/nemotron-3-8b-base-4k-hf" @@ -882,6 +889,11 @@ class NemotronModel(NemotronPreTrainedModel): if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + if isinstance(attention_mask, BlockMask): + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index f9398ad1c5f..2f6dc1fa9c9 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -23,6 +23,7 @@ from ...utils import ( LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, + is_torch_flex_attn_available, logging, replace_return_docstrings, ) @@ -30,6 +31,12 @@ from ...utils.deprecation import deprecate_kwarg from .configuration_olmo import OlmoConfig +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "OlmoConfig" @@ -610,6 +617,11 @@ class OlmoModel(OlmoPreTrainedModel): if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + if isinstance(attention_mask, BlockMask): + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index 76b8d88ce6b..05b8d172238 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -22,6 +22,7 @@ from ...utils import ( LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, + is_torch_flex_attn_available, logging, replace_return_docstrings, ) @@ -29,6 +30,12 @@ from ...utils.deprecation import deprecate_kwarg from .configuration_olmo2 import Olmo2Config +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "Olmo2Config" @@ -611,6 +618,11 @@ class Olmo2Model(Olmo2PreTrainedModel): if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + if isinstance(attention_mask, BlockMask): + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index 22d2bb40b9e..a306e84b23a 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -40,12 +40,19 @@ from ...utils import ( add_start_docstrings_to_model_forward, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, + is_torch_flex_attn_available, logging, replace_return_docstrings, ) from .configuration_opt import OPTConfig +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward @@ -642,6 +649,11 @@ class OPTDecoder(OPTPreTrainedModel): if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + if isinstance(attention_mask, BlockMask): + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index b6060e3e24d..df6153c8118 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -42,6 +42,7 @@ from ...utils import ( add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + is_torch_flex_attn_available, logging, replace_return_docstrings, ) @@ -49,6 +50,12 @@ from ...utils.deprecation import deprecate_kwarg from .configuration_persimmon import PersimmonConfig +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "adept/persimmon-8b-base" @@ -682,6 +689,11 @@ class PersimmonModel(PersimmonPreTrainedModel): if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + if isinstance(attention_mask, BlockMask): + return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail
diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py
index 2950e27a2fd..cac12d59b00 100644
--- a/src/transformers/models/phi/modeling_phi.py
+++ b/src/transformers/models/phi/modeling_phi.py
@@ -28,6 +28,7 @@ from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
+    is_torch_flex_attn_available,
     logging,
     replace_return_docstrings,
 )
@@ -35,6 +36,12 @@ from ...utils.deprecation import deprecate_kwarg
 from .configuration_phi import PhiConfig
 
 
+if is_torch_flex_attn_available():
+    from torch.nn.attention.flex_attention import BlockMask
+
+    from ...integrations.flex_attention import make_flex_block_causal_mask
+
+
 logger = logging.get_logger(__name__)
 
 _CHECKPOINT_FOR_DOC = "microsoft/phi-1"
@@ -608,6 +615,11 @@ class PhiModel(PhiPreTrainedModel):
             if attention_mask is not None and (attention_mask == 0.0).any():
                 return attention_mask
             return None
+        if self.config._attn_implementation == "flex_attention":
+            if isinstance(attention_mask, torch.Tensor):
+                attention_mask = make_flex_block_causal_mask(attention_mask)
+            if isinstance(attention_mask, BlockMask):
+                return attention_mask
 
         # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
         # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py
index 591dbdfe5c2..42e190d2e44 100644
--- a/src/transformers/models/pix2struct/modeling_pix2struct.py
+++ b/src/transformers/models/pix2struct/modeling_pix2struct.py
@@ -39,6 +39,7 @@ from ...utils import (
     DUMMY_MASK,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
+    is_torch_flex_attn_available,
     is_torch_fx_proxy,
     is_torchdynamo_compiling,
     logging,
@@ -47,6 +48,12 @@ from ...utils import (
 from .configuration_pix2struct import Pix2StructConfig, Pix2StructTextConfig, Pix2StructVisionConfig
 
 
+if is_torch_flex_attn_available():
+    from torch.nn.attention.flex_attention import BlockMask
+
+    from ...integrations.flex_attention import make_flex_block_causal_mask
+
+
 logger = logging.get_logger(__name__)
 
 # General docstring
@@ -1590,6 +1597,11 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel):
             if attention_mask is not None and (attention_mask == 0.0).any():
                 return attention_mask
             return None
+        if self.config._attn_implementation == "flex_attention":
+            if isinstance(attention_mask, torch.Tensor):
+                attention_mask = make_flex_block_causal_mask(attention_mask)
+            if isinstance(attention_mask, BlockMask):
+                return attention_mask
 
         # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
         # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
diff --git a/src/transformers/models/pop2piano/modeling_pop2piano.py b/src/transformers/models/pop2piano/modeling_pop2piano.py
index 0b9ad5724b9..0c4a2eda978 100644
--- a/src/transformers/models/pop2piano/modeling_pop2piano.py
+++ b/src/transformers/models/pop2piano/modeling_pop2piano.py
@@ -38,6 +38,7 @@ from ...pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indi
 from ...utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
+    is_torch_flex_attn_available,
     is_torch_fx_proxy,
     is_torchdynamo_compiling,
     logging,
@@ -46,6 +47,12 @@ from ...utils import (
 from .configuration_pop2piano import Pop2PianoConfig
 
 
+if is_torch_flex_attn_available():
+    from torch.nn.attention.flex_attention import BlockMask
+
+    from ...integrations.flex_attention import make_flex_block_causal_mask
+
+
 logger = logging.get_logger(__name__)
 
 _load_pop2piano_layer_norm = True
@@ -1003,6 +1010,11 @@ class Pop2PianoStack(Pop2PianoPreTrainedModel):
             if attention_mask is not None and (attention_mask == 0.0).any():
                 return attention_mask
             return None
+        if self.config._attn_implementation == "flex_attention":
+            if isinstance(attention_mask, torch.Tensor):
+                attention_mask = make_flex_block_causal_mask(attention_mask)
+            if isinstance(attention_mask, BlockMask):
+                return attention_mask
 
         # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
         # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py
index faa6ac2c813..4621851fea8 100755
--- a/src/transformers/models/stablelm/modeling_stablelm.py
+++ b/src/transformers/models/stablelm/modeling_stablelm.py
@@ -44,6 +44,7 @@ from ...utils import (
     add_start_docstrings_to_model_forward,
     is_flash_attn_2_available,
     is_flash_attn_greater_or_equal_2_10,
+    is_torch_flex_attn_available,
     logging,
     replace_return_docstrings,
 )
@@ -51,6 +52,12 @@ from ...utils.deprecation import deprecate_kwarg
 from .configuration_stablelm import StableLmConfig
 
 
+if is_torch_flex_attn_available():
+    from torch.nn.attention.flex_attention import BlockMask
+
+    from ...integrations.flex_attention import make_flex_block_causal_mask
+
+
 if is_flash_attn_2_available():
     from ...modeling_flash_attention_utils import _flash_attention_forward
 
@@ -937,6 +944,11 @@ class StableLmModel(StableLmPreTrainedModel):
             if attention_mask is not None and (attention_mask == 0.0).any():
                 return attention_mask
             return None
+        if self.config._attn_implementation == "flex_attention":
+            if isinstance(attention_mask, torch.Tensor):
+                attention_mask = make_flex_block_causal_mask(attention_mask)
+            if isinstance(attention_mask, BlockMask):
+                return attention_mask
 
         # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
         # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py
index 0d13dfaabbb..a347654a490 100644
--- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py
+++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py
@@ -40,6 +40,7 @@ from ...utils import (
     DUMMY_MASK,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
+    is_torch_flex_attn_available,
     is_torch_fx_proxy,
     is_torchdynamo_compiling,
     logging,
@@ -48,6 +49,12 @@ from ...utils import (
 from .configuration_switch_transformers import SwitchTransformersConfig
 
 
+if is_torch_flex_attn_available():
+    from torch.nn.attention.flex_attention import BlockMask
+
+    from ...integrations.flex_attention import make_flex_block_causal_mask
+
+
 logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "SwitchTransformersConfig"
@@ -1139,6 +1146,11 @@ class SwitchTransformersStack(SwitchTransformersPreTrainedModel):
             if attention_mask is not None and (attention_mask == 0.0).any():
                 return attention_mask
             return None
+        if self.config._attn_implementation == "flex_attention":
+            if isinstance(attention_mask, torch.Tensor):
+                attention_mask = make_flex_block_causal_mask(attention_mask)
+            if isinstance(attention_mask, BlockMask):
+                return attention_mask
 
         # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
         # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py
index e6f3a74e6d1..ba96c10ed0c 100644
--- a/src/transformers/models/t5/modeling_t5.py
+++ b/src/transformers/models/t5/modeling_t5.py
@@ -44,6 +44,7 @@ from ...utils import (
     DUMMY_MASK,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
+    is_torch_flex_attn_available,
     is_torch_fx_proxy,
     is_torchdynamo_compiling,
     logging,
@@ -53,6 +54,12 @@ from ...utils.model_parallel_utils import assert_device_map, get_device_map
 from .configuration_t5 import T5Config
 
 
+if is_torch_flex_attn_available():
+    from torch.nn.attention.flex_attention import BlockMask
+
+    from ...integrations.flex_attention import make_flex_block_causal_mask
+
+
 logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "T5Config"
@@ -1208,6 +1215,11 @@ class T5Stack(T5PreTrainedModel):
             if attention_mask is not None and (attention_mask == 0.0).any():
                 return attention_mask
             return None
+        if self.config._attn_implementation == "flex_attention":
+            if isinstance(attention_mask, torch.Tensor):
+                attention_mask = make_flex_block_causal_mask(attention_mask)
+            if isinstance(attention_mask, BlockMask):
+                return attention_mask
 
         # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
         # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
diff --git a/src/transformers/models/udop/modeling_udop.py b/src/transformers/models/udop/modeling_udop.py
index 3614503e731..2b0e27f45e1 100644
--- a/src/transformers/models/udop/modeling_udop.py
+++ b/src/transformers/models/udop/modeling_udop.py
@@ -43,11 +43,18 @@ from ...utils import (
     ModelOutput,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
+    is_torch_flex_attn_available,
     is_torchdynamo_compiling,
     replace_return_docstrings,
 )
 
 
+if is_torch_flex_attn_available():
+    from torch.nn.attention.flex_attention import BlockMask
+
+    from ...integrations.flex_attention import make_flex_block_causal_mask
+
+
 logger = logging.getLogger(__name__)
 
 
@@ -1541,6 +1548,11 @@ class UdopStack(UdopPreTrainedModel):
             if attention_mask is not None and (attention_mask == 0.0).any():
                 return attention_mask
             return None
+        if self.config._attn_implementation == "flex_attention":
+            if isinstance(attention_mask, torch.Tensor):
+                attention_mask = make_flex_block_causal_mask(attention_mask)
+            if isinstance(attention_mask, BlockMask):
+                return attention_mask
 
         # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
         # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py
index 5a410f1ff7f..7b868696f64 100644
--- a/src/transformers/models/umt5/modeling_umt5.py
+++ b/src/transformers/models/umt5/modeling_umt5.py
@@ -41,6 +41,7 @@ from ...utils import (
     DUMMY_MASK,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
+    is_torch_flex_attn_available,
     is_torch_fx_proxy,
     is_torchdynamo_compiling,
     logging,
@@ -49,6 +50,11 @@ from ...utils import (
 from .configuration_umt5 import UMT5Config
 
 
+if is_torch_flex_attn_available():
+    from torch.nn.attention.flex_attention import BlockMask
+
+    from ...integrations.flex_attention import make_flex_block_causal_mask
+
 logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "UMT5Config"
@@ -852,6 +858,11 @@ class UMT5Stack(UMT5PreTrainedModel):
             if attention_mask is not None and (attention_mask == 0.0).any():
                 return attention_mask
             return None
+        if self.config._attn_implementation == "flex_attention":
+            if isinstance(attention_mask, torch.Tensor):
+                attention_mask = make_flex_block_causal_mask(attention_mask)
+            if isinstance(attention_mask, BlockMask):
+                return attention_mask
 
         # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
         # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py
index fb892677fc4..2bcf4026a35 100644
--- a/src/transformers/models/whisper/modeling_whisper.py
+++ b/src/transformers/models/whisper/modeling_whisper.py
@@ -41,6 +41,7 @@ from ...utils import (
     add_start_docstrings_to_model_forward,
     is_flash_attn_2_available,
     is_flash_attn_greater_or_equal_2_10,
+    is_torch_flex_attn_available,
     logging,
     replace_return_docstrings,
 )
@@ -48,6 +49,11 @@ from .configuration_whisper import WhisperConfig
 from .generation_whisper import WhisperGenerationMixin
 
 
+if is_torch_flex_attn_available():
+    from torch.nn.attention.flex_attention import BlockMask
+
+    from ...integrations.flex_attention import make_flex_block_causal_mask
+
 if is_flash_attn_2_available():
     from ...modeling_flash_attention_utils import _flash_attention_forward
 
@@ -1378,6 +1384,11 @@ class WhisperDecoder(WhisperPreTrainedModel):
             if attention_mask is not None and (attention_mask == 0.0).any():
                 return attention_mask
             return None
+        if self.config._attn_implementation == "flex_attention":
+            if isinstance(attention_mask, torch.Tensor):
+                attention_mask = make_flex_block_causal_mask(attention_mask)
+            if isinstance(attention_mask, BlockMask):
+                return attention_mask
 
         # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
         # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
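
Usage note (reviewer sketch, not part of the patch): each model touched above gains the same `_update_causal_mask` branch, which converts a plain 2-D padding mask into a `BlockMask` via `make_flex_block_causal_mask` when `config._attn_implementation == "flex_attention"`, and passes an already-built `BlockMask` through unchanged. The snippet below is a minimal illustration of that conversion and of opting into the flex attention path at load time; it assumes torch >= 2.5, a CUDA device for the model forward, and uses `microsoft/phi-1` (one of the checkpoints whose modeling file is modified here) purely as an example.

# Illustrative sketch only; assumes torch >= 2.5 and a CUDA device for the forward pass.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.integrations.flex_attention import make_flex_block_causal_mask

tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1")
tokenizer.pad_token = tokenizer.eos_token  # phi's tokenizer ships without a pad token

# Two prompts of different lengths -> a padded 2-D attention mask of shape (batch, seq_len).
inputs = tokenizer(["def add(a, b):", "print('hi')"], padding=True, return_tensors="pt")

# The same conversion the new _update_causal_mask branch performs internally:
block_mask = make_flex_block_causal_mask(inputs["attention_mask"])
print(type(block_mask).__name__)  # BlockMask

# Opting into the flex attention code path enabled by this patch:
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/phi-1",
    attn_implementation="flex_attention",
    torch_dtype=torch.float16,
).to("cuda")

with torch.no_grad():
    outputs = model(**{k: v.to("cuda") for k, v in inputs.items()})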