Proper_flex (#36643)

* proper performant flex attention implementation

* wrapper for flex attention to compile only when triggered

* wrapper for flex attention to compile only when triggered

* attention mask type detection

* Update src/transformers/integrations/flex_attention.py

Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>

* nit

* nit

* nit

* nit

* gemma2 support

* add citation for torchtune

* Update src/transformers/models/llama/modeling_llama.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update flex_attention.py

* nit

* nit

* nit

* reset gemma2 modifications

* nit

* nit

* nit

* licencing

* apply changes to other models

* safe import

---------

Co-authored-by: Sung Ching Liu <sunny19981005@outlook.com>
Co-authored-by: Sung Ching Liu <22844540+bursteratom@users.noreply.github.com>
Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>
Arthur 2025-03-11 10:24:12 +01:00 committed by GitHub
parent d8663cb8c5
commit d126f35427
41 changed files with 645 additions and 12 deletions

View File

@ -139,6 +139,15 @@ else:
"SUPPORTED_TP_STYLES",
"translate_to_torch_parallel_style",
]
try:
if not is_torch_greater_or_equal("2.5"):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["flex_attention"] = [
"make_flex_block_causal_mask",
]
if TYPE_CHECKING:
from .aqlm import replace_with_aqlm_linear
@ -255,6 +264,13 @@ if TYPE_CHECKING:
translate_to_torch_parallel_style,
)
try:
if not is_torch_greater_or_equal("2.5"):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .flex_attention import make_flex_block_causal_mask
else:
import sys
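A minimal sketch of what the guarded export above enables, assuming torch 2.5 or newer is installed (on older torch the name is simply not exported from transformers.integrations):

from transformers.utils import is_torch_greater_or_equal

if is_torch_greater_or_equal("2.5"):
    # Only exported when the installed torch build ships flex attention.
    from transformers.integrations import make_flex_block_causal_mask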

View File

@ -1,4 +1,32 @@
from typing import Optional, Tuple
"""
Partially inspired by torchtune's flex attention implementation
Citation:
@software{torchtune,
title = {torchtune: PyTorch's finetuning library},
author = {torchtune maintainers and contributors},
url = {https://github.com/pytorch/torchtune},
license = {BSD-3-Clause},
month = apr,
year = {2024}
}
"""
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
import torch
@ -6,7 +34,114 @@ from ..utils import is_torch_flex_attn_available
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import flex_attention
from torch.nn.attention.flex_attention import (
BlockMask,
flex_attention,
)
from torch.nn.attention.flex_attention import (
create_block_mask as create_block_causal_mask_flex,
)
class WrappedFlexAttention:
"""
A singleton class so that flex attention is compiled only once, the first time it is called.
"""
_instance = None
_is_flex_compiled = False
_compiled_flex_attention = None
def __new__(cls, *args, **kwargs):
if cls._instance is None:
# Create a new instance if one doesn't already exist
cls._instance = super().__new__(cls)
return cls._instance
@torch.compiler.disable(recursive=False)
def __init__(self):
"""
Initialize or update the singleton instance.
"""
if self._is_flex_compiled is False:
self._compiled_flex_attention = torch.compile(flex_attention, dynamic=False)
self._is_flex_compiled = True
def __call__(self):
return self._compiled_flex_attention
def make_flex_block_causal_mask(attention_mask_2d: torch.Tensor) -> BlockMask:
"""
Create a block-causal document mask for a batch of sequences, both packed and unpacked.
Builds the block-causal mask logic and passes it to :func:`torch.nn.attention.flex_attention.create_block_mask`.
The resultant BlockMask is a compressed representation of the full block causal
mask. BlockMask is essential for performant computation of flex attention.
See: https://pytorch.org/blog/flexattention/
Args:
attention_mask_2d (torch.Tensor): Attention mask for packed and padded sequences
of shape (batch_size, total_seq_len). e.g.
For unpacked sequence:
[[1, 1, 1, 1, 0, 0, 0],
[1, 1, 1, 1, 1, 0, 0]]
For packed sequence:
[[1, 1, 1, 2, 2, 2, 0],
[1, 1, 2, 2, 2, 3, 3]]
Returns:
BlockMask
"""
device = attention_mask_2d.device
document_ids = attention_mask_2d
batch_size, total_seq_len = document_ids.shape
# Instead of passing a tensor mask, flex attention requires a mask_mod function
# that determines which elements of QK^T should be included in the attention
# computation prior to the softmax. For sample packing, we need the logic for
# both the causal mask and the document mask. See PyTorch's official
# blog post for more details: https://pytorch.org/blog/flexattention/#mask-mods
def causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
"""
Defines the logic of a block causal mask by combining both a standard causal mask
and a block diagonal document mask.
See :func:`~torchtune.modules.attention_utils.create_block_causal_mask`
for an illustration.
"""
causal_mask = q_idx >= kv_idx
document_mask = document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx]
padding_mask = document_ids[batch_idx, q_idx] > 0
return causal_mask & document_mask & padding_mask
return create_block_causal_mask_flex(
mask_mod=causal_mask_mod,
B=batch_size,
H=None,  # None means the mask is broadcast across all attention heads
Q_LEN=total_seq_len,
KV_LEN=total_seq_len,
device=device,
)
@torch.compiler.disable(recursive=False)
def compile_friendly_flex_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
**kwargs,
) -> torch.Tensor:
# The first call initializes the singleton wrapper object; the second call invokes it to return the compiled flex attention.
flex_attention_compiled = WrappedFlexAttention()()
return flex_attention_compiled(
query,
key,
value,
**kwargs,
)
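# Illustrative sketch, not part of this diff: because WrappedFlexAttention is a
# singleton, torch.compile runs only on the very first call; every later call
# reuses the same compiled kernel.
#
#     first = WrappedFlexAttention()()
#     second = WrappedFlexAttention()()
#     assert first is second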
def flex_attention_forward(
@ -14,30 +149,37 @@ def flex_attention_forward(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
attention_mask: Union[torch.Tensor, BlockMask],
scaling: Optional[float] = None,
softcap: Optional[float] = None,
head_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
causal_mask = attention_mask
block_mask = None
causal_mask = None
if isinstance(attention_mask, BlockMask):
block_mask = attention_mask
else:
causal_mask = attention_mask
if causal_mask is not None:
causal_mask = causal_mask[:, :, :, : key.shape[-2]]
def causal_mod(score, b, h, q_idx, kv_idx):
def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
if softcap is not None:
score = softcap * torch.tanh(score / softcap)
if causal_mask is not None:
score = score + causal_mask[b][0][q_idx][kv_idx]
score = score + causal_mask[batch_idx][0][q_idx][kv_idx]
if head_mask is not None:
score = score + head_mask[b][h][0][0]
score = score + head_mask[batch_idx][head_idx][0][0]
return score
attn_output, attention_weights = flex_attention(
attn_output, attention_weights = compile_friendly_flex_attention(
query,
key,
value,
score_mod=causal_mod,
score_mod=score_mod,
block_mask=block_mask,
enable_gqa=True,
scale=scaling,
# Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
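
A minimal usage sketch of the new helper (assumes torch >= 2.5 so that flex attention and BlockMask are available; the packed mask below follows the docstring convention above, with per-position document ids and 0 marking padding):

import torch

from transformers.integrations.flex_attention import make_flex_block_causal_mask

# Row 0 packs documents 1 and 2 and ends with one padding slot;
# row 1 packs documents 1, 2 and 3 with no padding.
attention_mask_2d = torch.tensor(
    [
        [1, 1, 1, 2, 2, 2, 0],
        [1, 1, 2, 2, 2, 3, 3],
    ]
)

block_mask = make_flex_block_causal_mask(attention_mask_2d)
# The result is a compressed BlockMask: each token attends only to earlier
# tokens of the same document, and padding positions are masked out entirely.
print(block_mask)

Passing this BlockMask as the attention_mask argument of flex_attention_forward takes the block_mask branch instead of the dense score_mod path.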

View File

@ -34,6 +34,7 @@ from ...utils import (
LossKwargs,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_torch_flex_attn_available,
logging,
replace_return_docstrings,
)
@ -48,6 +49,12 @@ if is_torch_available():
from torch import nn
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "AriaTextConfig"
@ -1014,6 +1021,11 @@ class AriaTextModel(AriaTextPreTrainedModel):
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
if isinstance(attention_mask, BlockMask):
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
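
The same mask-preparation hook recurs in each of the model files below; from the user side it is exercised simply by selecting the flex attention implementation. A hedged sketch (the checkpoint is just the documentation example used elsewhere in this diff; torch >= 2.5 and a recent GPU are assumed):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    attn_implementation="flex_attention",  # routes attention through flex_attention_forward
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

inputs = tokenizer("Packed and padded batches now become a BlockMask", return_tensors="pt").to(model.device)
with torch.no_grad():
    # The 2D attention_mask is turned into a BlockMask by the
    # `if self.config._attn_implementation == "flex_attention"` branch shown above.
    outputs = model(**inputs)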

View File

@ -36,10 +36,20 @@ from ...modeling_outputs import (
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import is_torchdynamo_compiling, logging
from ...utils import (
is_torch_flex_attn_available,
is_torchdynamo_compiling,
logging,
)
from .configuration_bloom import BloomConfig
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "bigscience/bloom-560m"
@ -743,6 +753,11 @@ class BloomModel(BloomPreTrainedModel):
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
if isinstance(attention_mask, BlockMask):
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail

View File

@ -41,6 +41,7 @@ from ...utils import (
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
is_torch_flex_attn_available,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
@ -48,6 +49,12 @@ from ...utils import (
from .configuration_chameleon import ChameleonConfig, ChameleonVQVAEConfig
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
if is_flash_attn_2_available():
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
@ -1389,6 +1396,11 @@ class ChameleonModel(ChameleonPreTrainedModel):
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
if isinstance(attention_mask, BlockMask):
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail

View File

@ -26,10 +26,22 @@ from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_utils import PreTrainedModel
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_torch_flex_attn_available,
logging,
)
from .configuration_codegen import CodeGenConfig
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "Salesforce/codegen-2B-mono"
@ -586,6 +598,11 @@ class CodeGenModel(CodeGenPreTrainedModel):
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
if isinstance(attention_mask, BlockMask):
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail

View File

@ -45,6 +45,7 @@ from ...utils import (
LossKwargs,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_torch_flex_attn_available,
logging,
replace_return_docstrings,
)
@ -52,6 +53,12 @@ from ...utils.deprecation import deprecate_kwarg
from .configuration_cohere import CohereConfig
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "CohereConfig"
@ -664,6 +671,11 @@ class CohereModel(CoherePreTrainedModel):
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
if isinstance(attention_mask, BlockMask):
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail

View File

@ -32,6 +32,7 @@ from ...utils import (
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
is_torch_flex_attn_available,
logging,
replace_return_docstrings,
)
@ -39,6 +40,12 @@ from ...utils.deprecation import deprecate_kwarg
from .configuration_dbrx import DbrxConfig
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
@ -1118,6 +1125,11 @@ class DbrxModel(DbrxPreTrainedModel):
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
if isinstance(attention_mask, BlockMask):
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail

View File

@ -48,6 +48,7 @@ from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_greater_or_equal_2_10,
is_torch_flex_attn_available,
logging,
replace_return_docstrings,
)
@ -55,6 +56,12 @@ from ...utils.deprecation import deprecate_kwarg
from .configuration_diffllama import DiffLlamaConfig
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "kajuma/DiffLlama-0.3B-handcut"
@ -903,6 +910,11 @@ class DiffLlamaModel(DiffLlamaPreTrainedModel):
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
if isinstance(attention_mask, BlockMask):
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail

View File

@ -41,6 +41,7 @@ from ...utils import (
LossKwargs,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_torch_flex_attn_available,
logging,
replace_return_docstrings,
)
@ -48,6 +49,12 @@ from ...utils.deprecation import deprecate_kwarg
from .configuration_emu3 import Emu3Config, Emu3TextConfig, Emu3VQVAEConfig
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
@ -1482,6 +1489,11 @@ class Emu3TextModel(Emu3PreTrainedModel):
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
if isinstance(attention_mask, BlockMask):
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail

View File

@ -43,6 +43,7 @@ from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_torch_flex_attn_available,
logging,
replace_return_docstrings,
)
@ -50,6 +51,12 @@ from ...utils.deprecation import deprecate_kwarg
from .configuration_gemma import GemmaConfig
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "google/gemma-7b"
@ -636,6 +643,11 @@ class GemmaModel(GemmaPreTrainedModel):
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
if isinstance(attention_mask, BlockMask):
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail

View File

@ -43,6 +43,7 @@ from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_torch_flex_attn_available,
logging,
replace_return_docstrings,
)
@ -50,6 +51,12 @@ from ...utils.deprecation import deprecate_kwarg
from .configuration_glm import GlmConfig
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "THUDM/glm-4-9b"
@ -645,6 +652,11 @@ class GlmModel(GlmPreTrainedModel):
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
if isinstance(attention_mask, BlockMask):
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail

View File

@ -42,12 +42,19 @@ from ...utils import (
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
is_torch_flex_attn_available,
is_torch_fx_available,
logging,
)
from .configuration_gpt_neo import GPTNeoConfig
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
@ -795,6 +802,11 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
if isinstance(attention_mask, BlockMask):
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail

View File

@ -29,12 +29,19 @@ from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_torch_flex_attn_available,
logging,
replace_return_docstrings,
)
from .configuration_gpt_neox import GPTNeoXConfig
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
@ -639,6 +646,11 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
if isinstance(attention_mask, BlockMask):
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail

View File

@ -29,10 +29,19 @@ from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import PreTrainedModel
from ...utils import logging
from ...utils import (
is_torch_flex_attn_available,
logging,
)
from .configuration_gpt_neox_japanese import GPTNeoXJapaneseConfig
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "abeja/gpt-neox-japanese-2.7b"
@ -665,6 +674,11 @@ class GPTNeoXJapaneseModel(GPTNeoXJapanesePreTrainedModel):
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
if isinstance(attention_mask, BlockMask):
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail

View File

@ -40,6 +40,7 @@ from ...utils import (
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
is_torch_flex_attn_available,
is_torch_fx_proxy,
logging,
)
@ -47,6 +48,12 @@ from ...utils.model_parallel_utils import assert_device_map, get_device_map
from .configuration_gptj import GPTJConfig
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
@ -894,6 +901,11 @@ class GPTJModel(GPTJPreTrainedModel):
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
if isinstance(attention_mask, BlockMask):
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail

View File

@ -37,6 +37,7 @@ from ...utils import (
LossKwargs,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_torch_flex_attn_available,
logging,
replace_return_docstrings,
)
@ -44,6 +45,12 @@ from ...utils.deprecation import deprecate_kwarg
from .configuration_granite import GraniteConfig
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "GraniteConfig"
@ -648,6 +655,11 @@ class GraniteModel(GranitePreTrainedModel):
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
if isinstance(attention_mask, BlockMask):
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail

View File

@ -37,12 +37,19 @@ from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_greater_or_equal_2_10,
is_torch_flex_attn_available,
logging,
replace_return_docstrings,
)
from .configuration_granitemoe import GraniteMoeConfig
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "GraniteMoeConfig"
@ -1121,6 +1128,11 @@ class GraniteMoeModel(GraniteMoePreTrainedModel):
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
if isinstance(attention_mask, BlockMask):
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail

View File

@ -37,12 +37,19 @@ from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_greater_or_equal_2_10,
is_torch_flex_attn_available,
logging,
replace_return_docstrings,
)
from .configuration_granitemoeshared import GraniteMoeSharedConfig
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
@ -1066,6 +1073,11 @@ class GraniteMoeSharedModel(GraniteMoeSharedPreTrainedModel):
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
if isinstance(attention_mask, BlockMask):
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail

View File

@ -44,6 +44,7 @@ from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_torch_flex_attn_available,
logging,
replace_return_docstrings,
)
@ -51,6 +52,12 @@ from ...utils.deprecation import deprecate_kwarg
from .configuration_helium import HeliumConfig
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "google/helium-7b"
@ -632,6 +639,11 @@ class HeliumModel(HeliumPreTrainedModel):
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
if isinstance(attention_mask, BlockMask):
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail

View File

@ -38,6 +38,7 @@ from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_torch_flex_attn_available,
logging,
replace_return_docstrings,
)
@ -46,6 +47,12 @@ from .perceiver import IdeficsPerceiverResampler
from .vision import IdeficsVisionTransformer
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "IdeficsConfig"
@ -1366,6 +1373,11 @@ class IdeficsModel(IdeficsPreTrainedModel):
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
if isinstance(attention_mask, BlockMask):
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail

View File

@ -38,6 +38,7 @@ from ...utils import (
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
is_torch_flex_attn_available,
logging,
replace_return_docstrings,
)
@ -45,6 +46,12 @@ from ...utils.deprecation import deprecate_kwarg
from .configuration_jetmoe import JetMoeConfig
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
@ -1127,6 +1134,11 @@ class JetMoeModel(JetMoePreTrainedModel):
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
if isinstance(attention_mask, BlockMask):
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail

View File

@ -44,6 +44,7 @@ from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_torch_flex_attn_available,
logging,
replace_return_docstrings,
)
@ -51,6 +52,12 @@ from ...utils.deprecation import deprecate_kwarg
from .configuration_llama import LlamaConfig
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "meta-llama/Llama-2-7b-hf"
@ -634,6 +641,11 @@ class LlamaModel(LlamaPreTrainedModel):
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
if isinstance(attention_mask, BlockMask):
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail

View File

@ -40,6 +40,7 @@ from ...utils import (
DUMMY_MASK,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_torch_flex_attn_available,
is_torch_fx_proxy,
is_torchdynamo_compiling,
logging,
@ -48,6 +49,12 @@ from ...utils import (
from .configuration_longt5 import LongT5Config
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "LongT5Config"
@ -1603,6 +1610,11 @@ class LongT5Stack(LongT5PreTrainedModel):
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
if isinstance(attention_mask, BlockMask):
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail

View File

@ -32,6 +32,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_torch_flex_attn_available,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
@ -40,6 +41,12 @@ from ...utils.deprecation import deprecate_kwarg
from .configuration_mllama import MllamaConfig, MllamaTextConfig, MllamaVisionConfig
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
@ -1081,6 +1088,11 @@ class MllamaPreTrainedModel(PreTrainedModel):
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
if isinstance(attention_mask, BlockMask):
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail

View File

@ -46,12 +46,19 @@ from ...processing_utils import Unpack
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_torch_flex_attn_available,
logging,
replace_return_docstrings,
)
from .configuration_moonshine import MoonshineConfig
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "MoonshineConfig"
@ -998,6 +1005,11 @@ class MoonshineDecoder(MoonshinePreTrainedModel):
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
if isinstance(attention_mask, BlockMask):
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail

View File

@ -44,6 +44,7 @@ from ...utils import (
DUMMY_MASK,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_torch_flex_attn_available,
is_torch_fx_proxy,
is_torchdynamo_compiling,
logging,
@ -53,6 +54,11 @@ from ...utils.model_parallel_utils import assert_device_map, get_device_map
from .configuration_mt5 import MT5Config
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "MT5Config"
@ -1195,6 +1201,11 @@ class MT5Stack(MT5PreTrainedModel):
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
if isinstance(attention_mask, BlockMask):
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail

View File

@ -43,6 +43,7 @@ from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_greater_or_equal_2_10,
is_torch_flex_attn_available,
logging,
replace_return_docstrings,
)
@ -50,6 +51,12 @@ from ...utils.deprecation import deprecate_kwarg
from .configuration_nemotron import NemotronConfig
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "nvidia/nemotron-3-8b-base-4k-hf"
@ -882,6 +889,11 @@ class NemotronModel(NemotronPreTrainedModel):
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
if isinstance(attention_mask, BlockMask):
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail

View File

@ -23,6 +23,7 @@ from ...utils import (
LossKwargs,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_torch_flex_attn_available,
logging,
replace_return_docstrings,
)
@ -30,6 +31,12 @@ from ...utils.deprecation import deprecate_kwarg
from .configuration_olmo import OlmoConfig
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "OlmoConfig"
@ -610,6 +617,11 @@ class OlmoModel(OlmoPreTrainedModel):
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
if isinstance(attention_mask, BlockMask):
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail

View File

@ -22,6 +22,7 @@ from ...utils import (
LossKwargs,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_torch_flex_attn_available,
logging,
replace_return_docstrings,
)
@ -29,6 +30,12 @@ from ...utils.deprecation import deprecate_kwarg
from .configuration_olmo2 import Olmo2Config
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "Olmo2Config"
@ -611,6 +618,11 @@ class Olmo2Model(Olmo2PreTrainedModel):
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
if isinstance(attention_mask, BlockMask):
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail

View File

@ -40,12 +40,19 @@ from ...utils import (
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
is_torch_flex_attn_available,
logging,
replace_return_docstrings,
)
from .configuration_opt import OPTConfig
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
@ -642,6 +649,11 @@ class OPTDecoder(OPTPreTrainedModel):
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
if isinstance(attention_mask, BlockMask):
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail

View File

@ -42,6 +42,7 @@ from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_torch_flex_attn_available,
logging,
replace_return_docstrings,
)
@ -49,6 +50,12 @@ from ...utils.deprecation import deprecate_kwarg
from .configuration_persimmon import PersimmonConfig
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "adept/persimmon-8b-base"
@ -682,6 +689,11 @@ class PersimmonModel(PersimmonPreTrainedModel):
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
if isinstance(attention_mask, BlockMask):
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail

View File

@ -28,6 +28,7 @@ from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_torch_flex_attn_available,
logging,
replace_return_docstrings,
)
@ -35,6 +36,12 @@ from ...utils.deprecation import deprecate_kwarg
from .configuration_phi import PhiConfig
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "microsoft/phi-1"
@ -608,6 +615,11 @@ class PhiModel(PhiPreTrainedModel):
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
if isinstance(attention_mask, BlockMask):
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail

View File

@ -39,6 +39,7 @@ from ...utils import (
DUMMY_MASK,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_torch_flex_attn_available,
is_torch_fx_proxy,
is_torchdynamo_compiling,
logging,
@ -47,6 +48,12 @@ from ...utils import (
from .configuration_pix2struct import Pix2StructConfig, Pix2StructTextConfig, Pix2StructVisionConfig
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
# General docstring
@ -1590,6 +1597,11 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel):
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
if isinstance(attention_mask, BlockMask):
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail

View File

@ -38,6 +38,7 @@ from ...pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indi
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_torch_flex_attn_available,
is_torch_fx_proxy,
is_torchdynamo_compiling,
logging,
@ -46,6 +47,12 @@ from ...utils import (
from .configuration_pop2piano import Pop2PianoConfig
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
_load_pop2piano_layer_norm = True
@ -1003,6 +1010,11 @@ class Pop2PianoStack(Pop2PianoPreTrainedModel):
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
if isinstance(attention_mask, BlockMask):
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail

View File

@ -44,6 +44,7 @@ from ...utils import (
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
is_torch_flex_attn_available,
logging,
replace_return_docstrings,
)
@ -51,6 +52,12 @@ from ...utils.deprecation import deprecate_kwarg
from .configuration_stablelm import StableLmConfig
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
@ -937,6 +944,11 @@ class StableLmModel(StableLmPreTrainedModel):
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
if isinstance(attention_mask, BlockMask):
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail

View File

@ -40,6 +40,7 @@ from ...utils import (
DUMMY_MASK,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_torch_flex_attn_available,
is_torch_fx_proxy,
is_torchdynamo_compiling,
logging,
@ -48,6 +49,12 @@ from ...utils import (
from .configuration_switch_transformers import SwitchTransformersConfig
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "SwitchTransformersConfig"
@ -1139,6 +1146,11 @@ class SwitchTransformersStack(SwitchTransformersPreTrainedModel):
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
if isinstance(attention_mask, BlockMask):
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail

View File

@ -44,6 +44,7 @@ from ...utils import (
DUMMY_MASK,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_torch_flex_attn_available,
is_torch_fx_proxy,
is_torchdynamo_compiling,
logging,
@ -53,6 +54,12 @@ from ...utils.model_parallel_utils import assert_device_map, get_device_map
from .configuration_t5 import T5Config
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "T5Config"
@ -1208,6 +1215,11 @@ class T5Stack(T5PreTrainedModel):
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
if isinstance(attention_mask, BlockMask):
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail

View File

@ -43,11 +43,18 @@ from ...utils import (
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_torch_flex_attn_available,
is_torchdynamo_compiling,
replace_return_docstrings,
)
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.getLogger(__name__)
@ -1541,6 +1548,11 @@ class UdopStack(UdopPreTrainedModel):
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
if isinstance(attention_mask, BlockMask):
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail

View File

@ -41,6 +41,7 @@ from ...utils import (
DUMMY_MASK,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_torch_flex_attn_available,
is_torch_fx_proxy,
is_torchdynamo_compiling,
logging,
@ -49,6 +50,11 @@ from ...utils import (
from .configuration_umt5 import UMT5Config
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "UMT5Config"
@ -852,6 +858,11 @@ class UMT5Stack(UMT5PreTrainedModel):
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
if isinstance(attention_mask, BlockMask):
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail

View File

@ -41,6 +41,7 @@ from ...utils import (
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
is_torch_flex_attn_available,
logging,
replace_return_docstrings,
)
@ -48,6 +49,11 @@ from .configuration_whisper import WhisperConfig
from .generation_whisper import WhisperGenerationMixin
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
@ -1378,6 +1384,11 @@ class WhisperDecoder(WhisperPreTrainedModel):
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
if isinstance(attention_mask, BlockMask):
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail