Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
Proper_flex (#36643)
* proper performant flex attention implementation
* wrapper for flex attention to compile only when triggered
* wrapper for flex attention to compile only when triggered
* attention mask type detection
* Update src/transformers/integrations/flex_attention.py
  Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>
* nit
* nit
* nit
* nit
* gemma2 support
* add citation for torchtune
* Update src/transformers/models/llama/modeling_llama.py
  Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Update flex_attention.py
* nit
* nit
* nit
* reset gemma2 modifications
* nit
* nit
* nit
* licencing
* apply changes to other models
* safe import

---------

Co-authored-by: Sung Ching Liu <sunny19981005@outlook.com>
Co-authored-by: Sung Ching Liu <22844540+bursteratom@users.noreply.github.com>
Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>
This commit is contained in: parent d8663cb8c5, commit d126f35427
@@ -139,6 +139,15 @@ else:
        "SUPPORTED_TP_STYLES",
        "translate_to_torch_parallel_style",
    ]
try:
    if not is_torch_greater_or_equal("2.5"):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["flex_attention"] = [
        "make_flex_block_causal_mask",
    ]

if TYPE_CHECKING:
    from .aqlm import replace_with_aqlm_linear
@@ -255,6 +264,13 @@ if TYPE_CHECKING:
            translate_to_torch_parallel_style,
        )

    try:
        if not is_torch_greater_or_equal("2.5"):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .flex_attention import make_flex_block_causal_mask
else:
    import sys

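The two guarded blocks above gate the new export on torch >= 2.5, the version that ships torch.nn.attention.flex_attention. As a minimal sketch (not part of the diff, and assuming the helper is importable from transformers.utils as it is inside this file), a caller would consume the guarded symbol like this:

    from transformers.utils import is_torch_greater_or_equal

    if is_torch_greater_or_equal("2.5"):
        # Resolves only when the optional-dependency guard above passes;
        # on older torch versions the symbol is simply not exported.
        from transformers.integrations import make_flex_block_causal_mask
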
@@ -1,4 +1,32 @@
from typing import Optional, Tuple
"""
Partially inspired by torchtune's flex attention implementation

Citation:
@software{torchtune,
  title = {torchtune: PyTorch's finetuning library},
  author = {torchtune maintainers and contributors},
  url = {https://github.com/pytorch/torchtune},
  license = {BSD-3-Clause},
  month = apr,
  year = {2024}
}
"""
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple, Union

import torch

@@ -6,7 +34,114 @@ from ..utils import is_torch_flex_attn_available


if is_torch_flex_attn_available():
    from torch.nn.attention.flex_attention import flex_attention
    from torch.nn.attention.flex_attention import (
        BlockMask,
        flex_attention,
    )
    from torch.nn.attention.flex_attention import (
        create_block_mask as create_block_causal_mask_flex,
    )


class WrappedFlexAttention:
    """
    We are doing a singleton class so that flex attention is compiled once when it's first called.
    """

    _instance = None
    _is_flex_compiled = False
    _compiled_flex_attention = None

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            # Create a new instance if one doesn't already exist
            cls._instance = super().__new__(cls)
        return cls._instance

    @torch.compiler.disable(recursive=False)
    def __init__(self):
        """
        Initialize or update the singleton instance.
        """
        if self._is_flex_compiled is False:
            self._compiled_flex_attention = torch.compile(flex_attention, dynamic=False)
            self._is_flex_compiled = True

    def __call__(self):
        return self._compiled_flex_attention
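
A quick illustration of the singleton behaviour above (a sketch for this write-up, not part of the diff): constructing the wrapper repeatedly reuses the one compiled kernel instead of invoking torch.compile again.

    first = WrappedFlexAttention()
    second = WrappedFlexAttention()
    assert first is second  # __new__ always hands back the same instance
    compiled = first()      # the torch.compile-wrapped flex_attention callable
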
def make_flex_block_causal_mask(attention_mask_2d: torch.Tensor) -> BlockMask:
    """
    Create a block causal document mask for a batch of sequences, both packed and unpacked.
    Creates the block causal logic and passes it into :func:`torch.nn.attention.flex_attention.create_block_mask`.
    The resulting BlockMask is a compressed representation of the full block causal
    mask. BlockMask is essential for performant computation of flex attention.
    See: https://pytorch.org/blog/flexattention/

    Args:
        attention_mask_2d (torch.Tensor): Attention mask for packed and padded sequences
            of shape (batch_size, total_seq_len). e.g.

            For unpacked sequence:
            [[1, 1, 1, 1, 0, 0, 0],
             [1, 1, 1, 1, 1, 0, 0]]

            For packed sequence:
            [[1, 1, 1, 2, 2, 2, 0],
             [1, 1, 2, 2, 2, 3, 3]]

    Returns:
        BlockMask
    """
    device = attention_mask_2d.device

    document_ids = attention_mask_2d
    batch_size, total_seq_len = document_ids.shape

    # Instead of passing a tensor mask, flex attention requires a mask_mod function
    # that determines which elements of QK^T should be included in the attention
    # computation prior to the softmax. For sample packing, we need the logic for
    # both the causal mask and the document mask. See PyTorch's official
    # blog post for more details: https://pytorch.org/blog/flexattention/#mask-mods
    def causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
        """
        Defines the logic of a block causal mask by combining both a standard causal mask
        and a block diagonal document mask.

        See :func:`~torchtune.modules.attention_utils.create_block_causal_mask`
        for an illustration.
        """
        causal_mask = q_idx >= kv_idx
        document_mask = document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx]
        padding_mask = document_ids[batch_idx, q_idx] > 0
        return causal_mask & document_mask & padding_mask

    return create_block_causal_mask_flex(
        mask_mod=causal_mask_mod,
        B=batch_size,
        H=None,  # attention head
        Q_LEN=total_seq_len,
        KV_LEN=total_seq_len,
        device=device,
    )
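
To make the packed/padded convention concrete, here is a small usage sketch (illustrative only, not part of the commit; assumes torch >= 2.5 so that flex attention and BlockMask are available):

    import torch

    attention_mask_2d = torch.tensor(
        [
            [1, 1, 1, 2, 2, 2, 0],  # two packed documents followed by one padding slot
            [1, 1, 1, 1, 1, 0, 0],  # a single document with right padding
        ]
    )
    block_mask = make_flex_block_causal_mask(attention_mask_2d)
    # block_mask is a compressed BlockMask: within each row, tokens may only attend
    # causally to earlier tokens of the same document id, and padding (id 0) attends to nothing.
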
@torch.compiler.disable(recursive=False)
def compile_friendly_flex_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    **kwargs,
) -> torch.Tensor:
    # The first call initialises the singleton wrapper object; the second call invokes
    # the object's method to return the compiled flex attention callable.
    flex_attention_compiled = WrappedFlexAttention()()
    return flex_attention_compiled(
        query,
        key,
        value,
        **kwargs,
    )


def flex_attention_forward(
@@ -14,30 +149,37 @@ def flex_attention_forward(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    attention_mask: Union[torch.Tensor, BlockMask],
    scaling: Optional[float] = None,
    softcap: Optional[float] = None,
    head_mask: Optional[torch.Tensor] = None,
    **kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
    causal_mask = attention_mask
    block_mask = None
    causal_mask = None
    if isinstance(attention_mask, BlockMask):
        block_mask = attention_mask
    else:
        causal_mask = attention_mask

    if causal_mask is not None:
        causal_mask = causal_mask[:, :, :, : key.shape[-2]]

    def causal_mod(score, b, h, q_idx, kv_idx):
    def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
        if softcap is not None:
            score = softcap * torch.tanh(score / softcap)
        if causal_mask is not None:
            score = score + causal_mask[b][0][q_idx][kv_idx]
            score = score + causal_mask[batch_idx][0][q_idx][kv_idx]
        if head_mask is not None:
            score = score + head_mask[b][h][0][0]
            score = score + head_mask[batch_idx][head_idx][0][0]
        return score

    attn_output, attention_weights = flex_attention(
    attn_output, attention_weights = compile_friendly_flex_attention(
        query,
        key,
        value,
        score_mod=causal_mod,
        score_mod=score_mod,
        block_mask=block_mask,
        enable_gqa=True,
        scale=scaling,
        # Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
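
End to end, the new path is exercised by selecting the implementation when loading a model. A sketch under the assumption of a torch 2.5+ environment; the checkpoint name is simply the Llama example that appears later in this diff:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",
        attn_implementation="flex_attention",
    )
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

    inputs = tokenizer(["Flex attention handles packed and padded batches."], return_tensors="pt")
    # During the forward pass the 2D attention mask is converted once into a BlockMask
    # (see make_flex_block_causal_mask above) and consumed by the compiled flex attention kernel.
    outputs = model(**inputs)
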
@@ -34,6 +34,7 @@ from ...utils import (
    LossKwargs,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_torch_flex_attn_available,
    logging,
    replace_return_docstrings,
)
@@ -48,6 +49,12 @@ if is_torch_available():
    from torch import nn


if is_torch_flex_attn_available():
    from torch.nn.attention.flex_attention import BlockMask

    from ...integrations.flex_attention import make_flex_block_causal_mask


logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "AriaTextConfig"

@@ -1014,6 +1021,11 @@ class AriaTextModel(AriaTextPreTrainedModel):
            if attention_mask is not None and (attention_mask == 0.0).any():
                return attention_mask
            return None
        if self.config._attn_implementation == "flex_attention":
            if isinstance(attention_mask, torch.Tensor):
                attention_mask = make_flex_block_causal_mask(attention_mask)
            if isinstance(attention_mask, BlockMask):
                return attention_mask

        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
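
The five-line branch added to `_update_causal_mask` above is the attention-mask type detection mentioned in the commit message: when the configured implementation is "flex_attention", a plain 2D torch.Tensor mask is converted once via make_flex_block_causal_mask, while a mask that is already a BlockMask (for example one prepared by the caller) is returned untouched. A condensed sketch of that decision (hypothetical helper name, for illustration only):

    def _maybe_to_block_mask(config, attention_mask):
        # Mirrors the branch above: convert 2D tensor masks once, pass BlockMasks through;
        # anything else falls back to the regular 4D causal-mask construction.
        if config._attn_implementation == "flex_attention":
            if isinstance(attention_mask, torch.Tensor):
                attention_mask = make_flex_block_causal_mask(attention_mask)
        return attention_mask

The identical hunk is repeated for every model file in the remainder of this commit, from Bloom below through Whisper.
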
@@ -36,10 +36,20 @@ from ...modeling_outputs import (
    TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import is_torchdynamo_compiling, logging
from ...utils import (
    is_torch_flex_attn_available,
    is_torchdynamo_compiling,
    logging,
)
from .configuration_bloom import BloomConfig


if is_torch_flex_attn_available():
    from torch.nn.attention.flex_attention import BlockMask

    from ...integrations.flex_attention import make_flex_block_causal_mask


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "bigscience/bloom-560m"
@@ -743,6 +753,11 @@ class BloomModel(BloomPreTrainedModel):
            if attention_mask is not None and (attention_mask == 0.0).any():
                return attention_mask
            return None
        if self.config._attn_implementation == "flex_attention":
            if isinstance(attention_mask, torch.Tensor):
                attention_mask = make_flex_block_causal_mask(attention_mask)
            if isinstance(attention_mask, BlockMask):
                return attention_mask

        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
@ -41,6 +41,7 @@ from ...utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_flash_attn_2_available,
|
||||
is_flash_attn_greater_or_equal_2_10,
|
||||
is_torch_flex_attn_available,
|
||||
is_torchdynamo_compiling,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
@ -48,6 +49,12 @@ from ...utils import (
|
||||
from .configuration_chameleon import ChameleonConfig, ChameleonVQVAEConfig
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
if is_flash_attn_2_available():
|
||||
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
||||
|
||||
@ -1389,6 +1396,11 @@ class ChameleonModel(ChameleonPreTrainedModel):
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
if isinstance(attention_mask, BlockMask):
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
|
@ -26,10 +26,22 @@ from ...generation import GenerationMixin
|
||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
|
||||
from ...utils import (
|
||||
add_code_sample_docstrings,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_torch_flex_attn_available,
|
||||
logging,
|
||||
)
|
||||
from .configuration_codegen import CodeGenConfig
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CHECKPOINT_FOR_DOC = "Salesforce/codegen-2B-mono"
|
||||
@ -586,6 +598,11 @@ class CodeGenModel(CodeGenPreTrainedModel):
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
if isinstance(attention_mask, BlockMask):
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
|
@ -45,6 +45,7 @@ from ...utils import (
|
||||
LossKwargs,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_torch_flex_attn_available,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
@ -52,6 +53,12 @@ from ...utils.deprecation import deprecate_kwarg
|
||||
from .configuration_cohere import CohereConfig
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CONFIG_FOR_DOC = "CohereConfig"
|
||||
@ -664,6 +671,11 @@ class CohereModel(CoherePreTrainedModel):
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
if isinstance(attention_mask, BlockMask):
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
|
@ -32,6 +32,7 @@ from ...utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_flash_attn_2_available,
|
||||
is_flash_attn_greater_or_equal_2_10,
|
||||
is_torch_flex_attn_available,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
@ -39,6 +40,12 @@ from ...utils.deprecation import deprecate_kwarg
|
||||
from .configuration_dbrx import DbrxConfig
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
if is_flash_attn_2_available():
|
||||
from ...modeling_flash_attention_utils import _flash_attention_forward
|
||||
|
||||
@ -1118,6 +1125,11 @@ class DbrxModel(DbrxPreTrainedModel):
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
if isinstance(attention_mask, BlockMask):
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
|
@ -48,6 +48,7 @@ from ...utils import (
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_flash_attn_greater_or_equal_2_10,
|
||||
is_torch_flex_attn_available,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
@ -55,6 +56,12 @@ from ...utils.deprecation import deprecate_kwarg
|
||||
from .configuration_diffllama import DiffLlamaConfig
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CHECKPOINT_FOR_DOC = "kajuma/DiffLlama-0.3B-handcut"
|
||||
@ -903,6 +910,11 @@ class DiffLlamaModel(DiffLlamaPreTrainedModel):
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
if isinstance(attention_mask, BlockMask):
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
|
@ -41,6 +41,7 @@ from ...utils import (
|
||||
LossKwargs,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_torch_flex_attn_available,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
@ -48,6 +49,12 @@ from ...utils.deprecation import deprecate_kwarg
|
||||
from .configuration_emu3 import Emu3Config, Emu3TextConfig, Emu3VQVAEConfig
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@ -1482,6 +1489,11 @@ class Emu3TextModel(Emu3PreTrainedModel):
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
if isinstance(attention_mask, BlockMask):
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
|
@ -43,6 +43,7 @@ from ...utils import (
|
||||
add_code_sample_docstrings,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_torch_flex_attn_available,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
@ -50,6 +51,12 @@ from ...utils.deprecation import deprecate_kwarg
|
||||
from .configuration_gemma import GemmaConfig
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CHECKPOINT_FOR_DOC = "google/gemma-7b"
|
||||
@ -636,6 +643,11 @@ class GemmaModel(GemmaPreTrainedModel):
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
if isinstance(attention_mask, BlockMask):
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
|
@ -43,6 +43,7 @@ from ...utils import (
|
||||
add_code_sample_docstrings,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_torch_flex_attn_available,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
@ -50,6 +51,12 @@ from ...utils.deprecation import deprecate_kwarg
|
||||
from .configuration_glm import GlmConfig
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CHECKPOINT_FOR_DOC = "THUDM/glm-4-9b"
|
||||
@ -645,6 +652,11 @@ class GlmModel(GlmPreTrainedModel):
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
if isinstance(attention_mask, BlockMask):
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
|
@ -42,12 +42,19 @@ from ...utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_flash_attn_2_available,
|
||||
is_flash_attn_greater_or_equal_2_10,
|
||||
is_torch_flex_attn_available,
|
||||
is_torch_fx_available,
|
||||
logging,
|
||||
)
|
||||
from .configuration_gpt_neo import GPTNeoConfig
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
if is_flash_attn_2_available():
|
||||
from ...modeling_flash_attention_utils import _flash_attention_forward
|
||||
|
||||
@ -795,6 +802,11 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
if isinstance(attention_mask, BlockMask):
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
|
@ -29,12 +29,19 @@ from ...utils import (
|
||||
add_code_sample_docstrings,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_torch_flex_attn_available,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from .configuration_gpt_neox import GPTNeoXConfig
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@ -639,6 +646,11 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
if isinstance(attention_mask, BlockMask):
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
|
@ -29,10 +29,19 @@ from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import logging
|
||||
from ...utils import (
|
||||
is_torch_flex_attn_available,
|
||||
logging,
|
||||
)
|
||||
from .configuration_gpt_neox_japanese import GPTNeoXJapaneseConfig
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CHECKPOINT_FOR_DOC = "abeja/gpt-neox-japanese-2.7b"
|
||||
@ -665,6 +674,11 @@ class GPTNeoXJapaneseModel(GPTNeoXJapanesePreTrainedModel):
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
if isinstance(attention_mask, BlockMask):
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
|
@ -40,6 +40,7 @@ from ...utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_flash_attn_2_available,
|
||||
is_flash_attn_greater_or_equal_2_10,
|
||||
is_torch_flex_attn_available,
|
||||
is_torch_fx_proxy,
|
||||
logging,
|
||||
)
|
||||
@ -47,6 +48,12 @@ from ...utils.model_parallel_utils import assert_device_map, get_device_map
|
||||
from .configuration_gptj import GPTJConfig
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
if is_flash_attn_2_available():
|
||||
from ...modeling_flash_attention_utils import _flash_attention_forward
|
||||
|
||||
@ -894,6 +901,11 @@ class GPTJModel(GPTJPreTrainedModel):
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
if isinstance(attention_mask, BlockMask):
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
|
@ -37,6 +37,7 @@ from ...utils import (
|
||||
LossKwargs,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_torch_flex_attn_available,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
@ -44,6 +45,12 @@ from ...utils.deprecation import deprecate_kwarg
|
||||
from .configuration_granite import GraniteConfig
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
_CONFIG_FOR_DOC = "GraniteConfig"
|
||||
|
||||
@ -648,6 +655,11 @@ class GraniteModel(GranitePreTrainedModel):
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
if isinstance(attention_mask, BlockMask):
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
|
@ -37,12 +37,19 @@ from ...utils import (
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_flash_attn_greater_or_equal_2_10,
|
||||
is_torch_flex_attn_available,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from .configuration_granitemoe import GraniteMoeConfig
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CONFIG_FOR_DOC = "GraniteMoeConfig"
|
||||
@ -1121,6 +1128,11 @@ class GraniteMoeModel(GraniteMoePreTrainedModel):
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
if isinstance(attention_mask, BlockMask):
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
|
@ -37,12 +37,19 @@ from ...utils import (
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_flash_attn_greater_or_equal_2_10,
|
||||
is_torch_flex_attn_available,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from .configuration_granitemoeshared import GraniteMoeSharedConfig
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@ -1066,6 +1073,11 @@ class GraniteMoeSharedModel(GraniteMoeSharedPreTrainedModel):
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
if isinstance(attention_mask, BlockMask):
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
|
@ -44,6 +44,7 @@ from ...utils import (
|
||||
add_code_sample_docstrings,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_torch_flex_attn_available,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
@ -51,6 +52,12 @@ from ...utils.deprecation import deprecate_kwarg
|
||||
from .configuration_helium import HeliumConfig
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CHECKPOINT_FOR_DOC = "google/helium-7b"
|
||||
@ -632,6 +639,11 @@ class HeliumModel(HeliumPreTrainedModel):
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
if isinstance(attention_mask, BlockMask):
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
|
@ -38,6 +38,7 @@ from ...pytorch_utils import ALL_LAYERNORM_LAYERS
|
||||
from ...utils import (
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_torch_flex_attn_available,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
@ -46,6 +47,12 @@ from .perceiver import IdeficsPerceiverResampler
|
||||
from .vision import IdeficsVisionTransformer
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CONFIG_FOR_DOC = "IdeficsConfig"
|
||||
@ -1366,6 +1373,11 @@ class IdeficsModel(IdeficsPreTrainedModel):
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
if isinstance(attention_mask, BlockMask):
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
|
@ -38,6 +38,7 @@ from ...utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_flash_attn_2_available,
|
||||
is_flash_attn_greater_or_equal_2_10,
|
||||
is_torch_flex_attn_available,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
@ -45,6 +46,12 @@ from ...utils.deprecation import deprecate_kwarg
|
||||
from .configuration_jetmoe import JetMoeConfig
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
if is_flash_attn_2_available():
|
||||
from ...modeling_flash_attention_utils import _flash_attention_forward
|
||||
|
||||
@ -1127,6 +1134,11 @@ class JetMoeModel(JetMoePreTrainedModel):
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
if isinstance(attention_mask, BlockMask):
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
|
@ -44,6 +44,7 @@ from ...utils import (
|
||||
add_code_sample_docstrings,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_torch_flex_attn_available,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
@ -51,6 +52,12 @@ from ...utils.deprecation import deprecate_kwarg
|
||||
from .configuration_llama import LlamaConfig
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CHECKPOINT_FOR_DOC = "meta-llama/Llama-2-7b-hf"
|
||||
@ -634,6 +641,11 @@ class LlamaModel(LlamaPreTrainedModel):
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
if isinstance(attention_mask, BlockMask):
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
|
@ -40,6 +40,7 @@ from ...utils import (
|
||||
DUMMY_MASK,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_torch_flex_attn_available,
|
||||
is_torch_fx_proxy,
|
||||
is_torchdynamo_compiling,
|
||||
logging,
|
||||
@ -48,6 +49,12 @@ from ...utils import (
|
||||
from .configuration_longt5 import LongT5Config
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CONFIG_FOR_DOC = "LongT5Config"
|
||||
@ -1603,6 +1610,11 @@ class LongT5Stack(LongT5PreTrainedModel):
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
if isinstance(attention_mask, BlockMask):
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
|
@ -32,6 +32,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
||||
from ...utils import (
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_torch_flex_attn_available,
|
||||
is_torchdynamo_compiling,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
@ -40,6 +41,12 @@ from ...utils.deprecation import deprecate_kwarg
|
||||
from .configuration_mllama import MllamaConfig, MllamaTextConfig, MllamaVisionConfig
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@ -1081,6 +1088,11 @@ class MllamaPreTrainedModel(PreTrainedModel):
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
if isinstance(attention_mask, BlockMask):
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
|
@ -46,12 +46,19 @@ from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_torch_flex_attn_available,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from .configuration_moonshine import MoonshineConfig
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
_CONFIG_FOR_DOC = "MoonshineConfig"
|
||||
|
||||
@ -998,6 +1005,11 @@ class MoonshineDecoder(MoonshinePreTrainedModel):
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
if isinstance(attention_mask, BlockMask):
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
|
@ -44,6 +44,7 @@ from ...utils import (
|
||||
DUMMY_MASK,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_torch_flex_attn_available,
|
||||
is_torch_fx_proxy,
|
||||
is_torchdynamo_compiling,
|
||||
logging,
|
||||
@ -53,6 +54,11 @@ from ...utils.model_parallel_utils import assert_device_map, get_device_map
|
||||
from .configuration_mt5 import MT5Config
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CONFIG_FOR_DOC = "MT5Config"
|
||||
@ -1195,6 +1201,11 @@ class MT5Stack(MT5PreTrainedModel):
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
if isinstance(attention_mask, BlockMask):
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
|
@ -43,6 +43,7 @@ from ...utils import (
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_flash_attn_greater_or_equal_2_10,
|
||||
is_torch_flex_attn_available,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
@ -50,6 +51,12 @@ from ...utils.deprecation import deprecate_kwarg
|
||||
from .configuration_nemotron import NemotronConfig
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CHECKPOINT_FOR_DOC = "nvidia/nemotron-3-8b-base-4k-hf"
|
||||
@ -882,6 +889,11 @@ class NemotronModel(NemotronPreTrainedModel):
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
if isinstance(attention_mask, BlockMask):
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
|
@ -23,6 +23,7 @@ from ...utils import (
|
||||
LossKwargs,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_torch_flex_attn_available,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
@ -30,6 +31,12 @@ from ...utils.deprecation import deprecate_kwarg
|
||||
from .configuration_olmo import OlmoConfig
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
_CONFIG_FOR_DOC = "OlmoConfig"
|
||||
|
||||
@ -610,6 +617,11 @@ class OlmoModel(OlmoPreTrainedModel):
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
if isinstance(attention_mask, BlockMask):
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
|
@ -22,6 +22,7 @@ from ...utils import (
|
||||
LossKwargs,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_torch_flex_attn_available,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
@ -29,6 +30,12 @@ from ...utils.deprecation import deprecate_kwarg
|
||||
from .configuration_olmo2 import Olmo2Config
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
_CONFIG_FOR_DOC = "Olmo2Config"
|
||||
|
||||
@ -611,6 +618,11 @@ class Olmo2Model(Olmo2PreTrainedModel):
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
if isinstance(attention_mask, BlockMask):
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
|
@ -40,12 +40,19 @@ from ...utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_flash_attn_2_available,
|
||||
is_flash_attn_greater_or_equal_2_10,
|
||||
is_torch_flex_attn_available,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from .configuration_opt import OPTConfig
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
if is_flash_attn_2_available():
|
||||
from ...modeling_flash_attention_utils import _flash_attention_forward
|
||||
|
||||
@ -642,6 +649,11 @@ class OPTDecoder(OPTPreTrainedModel):
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
if isinstance(attention_mask, BlockMask):
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
|
@ -42,6 +42,7 @@ from ...utils import (
|
||||
add_code_sample_docstrings,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_torch_flex_attn_available,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
@ -49,6 +50,12 @@ from ...utils.deprecation import deprecate_kwarg
|
||||
from .configuration_persimmon import PersimmonConfig
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CHECKPOINT_FOR_DOC = "adept/persimmon-8b-base"
|
||||
@ -682,6 +689,11 @@ class PersimmonModel(PersimmonPreTrainedModel):
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
if isinstance(attention_mask, BlockMask):
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
|
@ -28,6 +28,7 @@ from ...utils import (
|
||||
add_code_sample_docstrings,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_torch_flex_attn_available,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
@ -35,6 +36,12 @@ from ...utils.deprecation import deprecate_kwarg
|
||||
from .configuration_phi import PhiConfig
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CHECKPOINT_FOR_DOC = "microsoft/phi-1"
|
||||
@ -608,6 +615,11 @@ class PhiModel(PhiPreTrainedModel):
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
if isinstance(attention_mask, BlockMask):
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
|
@ -39,6 +39,7 @@ from ...utils import (
|
||||
DUMMY_MASK,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_torch_flex_attn_available,
|
||||
is_torch_fx_proxy,
|
||||
is_torchdynamo_compiling,
|
||||
logging,
|
||||
@ -47,6 +48,12 @@ from ...utils import (
|
||||
from .configuration_pix2struct import Pix2StructConfig, Pix2StructTextConfig, Pix2StructVisionConfig
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
# General docstring
|
||||
@ -1590,6 +1597,11 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel):
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
if isinstance(attention_mask, BlockMask):
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
|
@ -38,6 +38,7 @@ from ...pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indi
|
||||
from ...utils import (
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_torch_flex_attn_available,
|
||||
is_torch_fx_proxy,
|
||||
is_torchdynamo_compiling,
|
||||
logging,
|
||||
@ -46,6 +47,12 @@ from ...utils import (
|
||||
from .configuration_pop2piano import Pop2PianoConfig
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_load_pop2piano_layer_norm = True
|
||||
@ -1003,6 +1010,11 @@ class Pop2PianoStack(Pop2PianoPreTrainedModel):
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
if isinstance(attention_mask, BlockMask):
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
|
@ -44,6 +44,7 @@ from ...utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_flash_attn_2_available,
|
||||
is_flash_attn_greater_or_equal_2_10,
|
||||
is_torch_flex_attn_available,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
@ -51,6 +52,12 @@ from ...utils.deprecation import deprecate_kwarg
|
||||
from .configuration_stablelm import StableLmConfig
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
if is_flash_attn_2_available():
|
||||
from ...modeling_flash_attention_utils import _flash_attention_forward
|
||||
|
||||
@ -937,6 +944,11 @@ class StableLmModel(StableLmPreTrainedModel):
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
if isinstance(attention_mask, BlockMask):
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
|
@ -40,6 +40,7 @@ from ...utils import (
|
||||
DUMMY_MASK,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_torch_flex_attn_available,
|
||||
is_torch_fx_proxy,
|
||||
is_torchdynamo_compiling,
|
||||
logging,
|
||||
@ -48,6 +49,12 @@ from ...utils import (
|
||||
from .configuration_switch_transformers import SwitchTransformersConfig
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CONFIG_FOR_DOC = "SwitchTransformersConfig"
|
||||
@ -1139,6 +1146,11 @@ class SwitchTransformersStack(SwitchTransformersPreTrainedModel):
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
if isinstance(attention_mask, BlockMask):
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
|
@ -44,6 +44,7 @@ from ...utils import (
|
||||
DUMMY_MASK,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_torch_flex_attn_available,
|
||||
is_torch_fx_proxy,
|
||||
is_torchdynamo_compiling,
|
||||
logging,
|
||||
@ -53,6 +54,12 @@ from ...utils.model_parallel_utils import assert_device_map, get_device_map
|
||||
from .configuration_t5 import T5Config
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CONFIG_FOR_DOC = "T5Config"
|
||||
@ -1208,6 +1215,11 @@ class T5Stack(T5PreTrainedModel):
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
if isinstance(attention_mask, BlockMask):
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
|
@ -43,11 +43,18 @@ from ...utils import (
|
||||
ModelOutput,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_torch_flex_attn_available,
|
||||
is_torchdynamo_compiling,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -1541,6 +1548,11 @@ class UdopStack(UdopPreTrainedModel):
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
if isinstance(attention_mask, BlockMask):
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
|
@ -41,6 +41,7 @@ from ...utils import (
|
||||
DUMMY_MASK,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_torch_flex_attn_available,
|
||||
is_torch_fx_proxy,
|
||||
is_torchdynamo_compiling,
|
||||
logging,
|
||||
@ -49,6 +50,11 @@ from ...utils import (
|
||||
from .configuration_umt5 import UMT5Config
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CONFIG_FOR_DOC = "UMT5Config"
|
||||
@ -852,6 +858,11 @@ class UMT5Stack(UMT5PreTrainedModel):
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
if isinstance(attention_mask, BlockMask):
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
|
@ -41,6 +41,7 @@ from ...utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_flash_attn_2_available,
|
||||
is_flash_attn_greater_or_equal_2_10,
|
||||
is_torch_flex_attn_available,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
@ -48,6 +49,11 @@ from .configuration_whisper import WhisperConfig
|
||||
from .generation_whisper import WhisperGenerationMixin
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
if is_flash_attn_2_available():
|
||||
from ...modeling_flash_attention_utils import _flash_attention_forward
|
||||
|
||||
@ -1378,6 +1384,11 @@ class WhisperDecoder(WhisperPreTrainedModel):
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
if isinstance(attention_mask, BlockMask):
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
|