[GPTNeoX] Flex Attention + Refactor (#34896)

* gpt neox flex attention + refactor

* some formatting

* small fix on dropout

* add assertion on flex attn test

* flaky ci :(

* add head mask support

* style

* handle dtype, replace torch where

* fixup flex with output attns

* code review and several other fixes

* Update src/transformers/modeling_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* style

* remove unnecessary comment

* remove incorrect comment

* make flex attn check more agnostic to versions and centralized

* change peft input dtype check to value since q and k could be affected by other stuff like RoPE

* i forgot

* flaky

* code review and small fixes

* Update src/transformers/models/gpt_neox/modeling_gpt_neox.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Anton Vlasjuk 2024-12-04 14:48:28 +01:00 committed by GitHub
parent accb7204f9
commit 46df859975
GPG Key ID: B5690EEEBB952194
6 changed files with 372 additions and 250 deletions


@@ -20,7 +20,10 @@ from typing import Optional, Tuple, TypedDict
import torch
import torch.nn.functional as F
from .utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal
from .utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal, logging
logger = logging.get_logger(__name__)
if is_flash_attn_2_available():
@@ -180,6 +183,47 @@ def prepare_fa2_from_position_ids(query, key, value, position_ids):
return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length))
def fa_peft_integration_check(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
target_dtype: Optional[torch.dtype] = None,
):
"""
PEFT usually casts the layer norms to float32 for training stability, so the input
hidden states get silently cast to float32 as well. Hence, we need to cast them back
to float16 / bfloat16 just to be sure everything works as expected.
This might slow down training & inference, so it is recommended not to cast the LayerNorms!
Args:
query (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value (`torch.Tensor`):
Input value states to be passed to Flash Attention API
target_dtype (`torch.dtype`, *optional*):
The dtype to convert the attention tensors to. Conversion can be ignored by
not providing the target dtype.
"""
if target_dtype is None:
return query, key, value
input_dtype = value.dtype
if input_dtype == torch.float32:
logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)
query = query.to(target_dtype)
key = key.to(target_dtype)
value = value.to(target_dtype)
return query, key, value
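
As a quick usage sketch of the helper above (the import path `transformers.modeling_flash_attention_utils` is assumed from the hunk context):

import torch
from transformers.modeling_flash_attention_utils import fa_peft_integration_check

# toy q/k/v that were silently upcast to float32 (e.g. because PEFT kept the layer norms in fp32)
query = torch.randn(1, 8, 4, 64, dtype=torch.float32)
key = torch.randn(1, 8, 4, 64, dtype=torch.float32)
value = torch.randn(1, 8, 4, 64, dtype=torch.float32)

query, key, value = fa_peft_integration_check(query, key, value, target_dtype=torch.bfloat16)
assert query.dtype == key.dtype == value.dtype == torch.bfloat16
# with target_dtype=None the helper is a no-op and returns the tensors unchanged
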
flash_241 = is_flash_attn_greater_or_equal("2.4.1")
deterministic_g = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
@@ -202,6 +246,7 @@ def _flash_attention_forward(
cu_seq_lens_k: Optional[torch.LongTensor] = None,
max_length_q: Optional[int] = None,
max_length_k: Optional[int] = None,
target_dtype: Optional[torch.dtype] = None,
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
@@ -248,6 +293,11 @@ def _flash_attention_forward(
if softcap is not None:
flash_kwargs["softcap"] = softcap
# PEFT may silently cast tensors to fp32; this converts them back to the correct dtype or is a no-op
query_states, key_states, value_states = fa_peft_integration_check(
query_states, key_states, value_states, target_dtype
)
# Contains at least one padding token in the sequence
if attention_mask is not None:
batch_size = query_states.shape[0]


@@ -89,6 +89,7 @@ from .utils import (
is_peft_available,
is_remote_url,
is_safetensors_available,
is_torch_flex_attn_available,
is_torch_greater_or_equal,
is_torch_sdpa_available,
is_torch_xla_available,
@@ -1342,6 +1343,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# SDPA support
_supports_sdpa = False
# Flex Attention support
_supports_flex_attn = False
# Has support for a `Cache` instance as `past_key_values`? Does it support a `StaticCache`?
_supports_cache_class = False
_supports_static_cache = False
@@ -1548,6 +1552,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
message += ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)'
if cls._supports_sdpa:
message += ', `"attn_implementation=sdpa"` (implementation using torch.nn.functional.scaled_dot_product_attention)'
if cls._supports_flex_attn:
message += (
', `"attn_implementation=flex_attention"` (implementation using torch\'s flex_attention)'
)
raise ValueError(message + ".")
# If a config is passed with a preset attn_implementation, we skip the automatic dispatch and use the user-provided config, with hard checks that the requested attention implementation is available.
@@ -1582,6 +1590,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
hard_check_only=False,
check_device_map=check_device_map,
)
elif requested_attn_implementation == "flex_attention":
config = cls._check_and_enable_flex_attn(config, hard_check_only=True)
elif requested_attn_implementation in [None, "sdpa"] and not is_torch_xla_available():
# use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif.
config = cls._check_and_enable_sdpa(
@@ -1778,7 +1788,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
"""
Checks the availability of SDPA for a given model.
If all checks pass and `hard_check_only` is False, the method will set the config attribute `_attn_implementation` to "flash_attention_2" so that the model can initialize the correct attention module.
If all checks pass and `hard_check_only` is False, the method will set the config attribute `_attn_implementation` to "sdpa" so that the model can initialize the correct attention module.
"""
if hard_check_only:
if not cls._supports_sdpa:
@@ -1803,6 +1813,35 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
config._attn_implementation = "sdpa"
return config
@classmethod
def _check_and_enable_flex_attn(cls, config, hard_check_only: bool = False) -> PretrainedConfig:
"""
Checks the availability of Flex Attention for a given model.
If all checks pass and `hard_check_only` is False, the method will set the config attribute `_attn_implementation` to "flex_attention" so that the model can initialize the correct attention module.
"""
if hard_check_only:
if not cls._supports_flex_attn:
raise ValueError(
f"{cls.__name__} does not support an attention implementation through torch's flex_attention."
" Please request the support for this architecture: https://github.com/huggingface/transformers/issues/34809."
" If you believe this error is a bug, please open an issue in Transformers GitHub repository"
' and load your model with the argument `attn_implementation="eager"` meanwhile.'
' Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`'
)
if not is_torch_flex_attn_available():
raise ImportError(
"PyTorch Flex Attention requirements in Transformers are not met. Please install torch>=2.5.0."
)
if not is_torch_flex_attn_available() or not cls._supports_flex_attn:
return config
if not hard_check_only:
config._attn_implementation = "flex_attention"
return config
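
A minimal usage sketch of the new dispatch path (torch>=2.5.0 required; any GPTNeoX checkpoint works, pythia-70m is just a small example):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/pythia-70m", attn_implementation="flex_attention"
)
assert model.config._attn_implementation == "flex_attention"
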
def enable_input_require_grads(self):
"""
Enables the gradients for the input embeddings. This is useful for fine-tuning adapter weights while keeping


@@ -18,7 +18,6 @@ from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from packaging import version
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -42,9 +41,9 @@ from ...modeling_outputs import (
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import PreTrainedModel
from ...utils import (
get_torch_version,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
is_torch_flex_attn_available,
logging,
)
from .configuration_gpt_neox import GPTNeoXConfig
@@ -53,6 +52,9 @@ from .configuration_gpt_neox import GPTNeoXConfig
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import flex_attention
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "trl-internal-testing/tiny-random-GPTNeoXForCausalLM"
@@ -76,6 +78,7 @@ class GPTNeoXPreTrainedModel(PreTrainedModel):
_supports_quantized_cache = True
_supports_static_cache = True
_supports_sdpa = True
_supports_flex_attn = True
def _init_weights(self, module):
"""Initialize the weights"""
@@ -92,6 +95,169 @@ class GPTNeoXPreTrainedModel(PreTrainedModel):
module.weight.data.fill_(1.0)
def eager_attention_forward(
query, key, value, attention_mask, head_mask, norm_factor, attention_dropout, training, **_kwargs
):
# q, k, v: [bs, num_attention_heads, seq_len, attn_head_size]
batch_size, num_attention_heads, query_length, attn_head_size = query.size()
key_length = key.size(-2)
query = query.view(batch_size * num_attention_heads, query_length, attn_head_size)
key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
attn_scores = torch.zeros(
batch_size * num_attention_heads,
query_length,
key_length,
dtype=query.dtype,
device=key.device,
)
attn_scores = torch.baddbmm(
attn_scores,
query,
key.transpose(1, 2),
beta=1.0,
alpha=norm_factor,
)
attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length)
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_scores = attn_scores + causal_mask
attn_weights = nn.functional.softmax(attn_scores, dim=-1)
attn_weights = attn_weights.to(value.dtype)
# Mask heads if we want to
if head_mask is not None:
attn_weights = attn_weights * head_mask
attn_weights = nn.functional.dropout(attn_weights, p=attention_dropout, training=training)
attn_output = torch.matmul(attn_weights, value)
# Reshape outputs
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
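
For reference, the baddbmm call above is a fused way of writing the scaled QK^T product into a pre-allocated buffer; a toy equivalence check with made-up shapes:

import torch

bs_heads, q_len, kv_len, head_dim = 8, 5, 7, 16
query = torch.randn(bs_heads, q_len, head_dim)
key = torch.randn(bs_heads, kv_len, head_dim)
norm_factor = head_dim ** -0.5

scores = torch.baddbmm(
    torch.zeros(bs_heads, q_len, kv_len),  # added with beta=1.0, so the zero buffer has no effect
    query,
    key.transpose(1, 2),
    beta=1.0,
    alpha=norm_factor,
)
assert torch.allclose(scores, norm_factor * (query @ key.transpose(1, 2)), atol=1e-5)
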
def flash_attention_forward(
query,
key,
value,
attention_mask,
norm_factor,
attention_dropout,
training,
target_dtype=None,
**_kwargs,
):
query_length = query.shape[-2]
# GPT-neo-X casts query and key in fp32 to apply rotary embedding in full precision
query = query.to(value.dtype)
key = key.to(value.dtype)
# Permute to get the expected shape for Flash Attention
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
attention_dropout = attention_dropout if training else 0.0
flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
# Compute attention
attn_output = _flash_attention_forward(
query,
key,
value,
attention_mask,
query_length,
dropout=attention_dropout,
softmax_scale=norm_factor,
is_causal=True,
use_top_left_mask=flash_attn_uses_top_left_mask,
target_dtype=target_dtype,
)
return attn_output, None
def sdpa_attention_forward(query, key, value, attention_mask, attention_dropout, training, **_kwargs):
q_len = query.shape[-2]
causal_mask = attention_mask
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key.shape[-2]]
# GPT-neo-X casts query and key in fp32 to apply rotary embedding in full precision
query = query.to(value.dtype)
key = key.to(value.dtype)
# Avoid torch==2.1.2 specific bug for the memory-efficient backend in SDPA
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
is_causal = True if causal_mask is None and q_len > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention(
query=query,
key=key,
value=value,
attn_mask=causal_mask,
dropout_p=attention_dropout if training else 0.0,
is_causal=is_causal,
)
# Reshape outputs
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, None
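
The is_causal switch above follows SDPA's convention: with no explicit mask and more than one query token, the kernel applies its own causal mask, while a single-token decoding step must pass is_causal=False. A toy sketch:

import torch
import torch.nn.functional as F

query = torch.randn(1, 8, 5, 64)              # prefill: q_len > 1 -> is_causal=True
key = value = torch.randn(1, 8, 5, 64)
prefill_out = F.scaled_dot_product_attention(query, key, value, attn_mask=None, is_causal=True)

query = torch.randn(1, 8, 1, 64)              # decoding step: q_len == 1 -> is_causal=False
key = value = torch.randn(1, 8, 9, 64)
decode_out = F.scaled_dot_product_attention(query, key, value, attn_mask=None, is_causal=False)
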
def flex_attention_forward(query, key, value, attention_mask, head_mask, norm_factor, **_kwargs):
causal_mask = attention_mask
if causal_mask is not None:
causal_mask = causal_mask[:, :, :, : key.shape[-2]]
def causal_mod(score, b, h, q_idx, kv_idx):
if causal_mask is not None:
score += causal_mask[b][0][q_idx][kv_idx]
if head_mask is not None:
score += head_mask[b][h][0][0]
return score
attn_output, attn_weights = flex_attention(
query,
key,
value,
score_mod=causal_mod,
enable_gqa=True,
scale=norm_factor,
# Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
# For simplification, we thus always return it as no additional computations are introduced.
return_lse=True,
)
# lse is returned in float32
attn_weights = attn_weights.to(value.dtype)
# Reshape outputs
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
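
A standalone sketch of the score_mod mechanism used above (torch>=2.5.0; without torch.compile this runs the slower reference implementation):

import torch
from torch.nn.attention.flex_attention import flex_attention

bsz, heads, seq, dim = 1, 4, 8, 32
query = torch.randn(bsz, heads, seq, dim)
key = torch.randn(bsz, heads, seq, dim)
value = torch.randn(bsz, heads, seq, dim)

# additive causal bias: 0 where attending is allowed, the most negative float elsewhere
causal_bias = torch.triu(torch.full((seq, seq), torch.finfo(torch.float32).min), diagonal=1)

def causal_mod(score, b, h, q_idx, kv_idx):
    return score + causal_bias[q_idx][kv_idx]

attn_output, lse = flex_attention(query, key, value, score_mod=causal_mod, return_lse=True)
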
GPTNEOX_ATTENTION_FUNCTION = {
"eager": eager_attention_forward,
"flash_attention_2": flash_attention_forward,
"sdpa": sdpa_attention_forward,
"flex_attention": flex_attention_forward,
}
class GPTNeoXAttention(nn.Module):
def __init__(self, config, layer_idx=None):
super().__init__()
@@ -147,20 +313,57 @@ class GPTNeoXAttention(nn.Module):
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
):
bsz, seq_len, _ = hidden_states.shape
# Apply attention-specific projections and rope
query, key, value, present = self._attn_projections_and_rope(
hidden_states=hidden_states,
position_ids=position_ids,
layer_past=layer_past,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
# Compute attention
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
# Checking for fallbacks in case an unsupported feature is requested
attention_type = self.config._attn_implementation
if (output_attentions or head_mask is not None) and self.config._attn_implementation in [
"sdpa",
"flash_attention_2",
]:
logger.warning_once(
f"Setting `attention_type` to `eager` because `{attention_type}` does not support"
f" `output_attentions=True` or `head_mask`."
)
attention_type = "eager"
# Reshape outputs
attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size)
elif (
self.training
and self.config.attention_dropout > 0
and self.config._attn_implementation == "flex_attention"
):
logger.warning_once(
f"Setting `attention_type` to `eager` because `dropout` is not supported in `{attention_type}`."
)
attention_type = "eager"
# Compute attention
attn_output, attn_weights = GPTNEOX_ATTENTION_FUNCTION[attention_type](
query,
key,
value,
attention_mask=attention_mask,
head_mask=head_mask,
norm_factor=self.norm_factor,
attention_dropout=self.config.attention_dropout,
training=self.training,
# Flash Attention 2 specific PEFT check
target_dtype=self._fa_peft_dtype_check(value),
)
# Reshape outputs and final projection
attn_output = attn_output.contiguous()
attn_output = attn_output.view(bsz, seq_len, -1)
attn_output = self.dense(attn_output)
outputs = (attn_output, present)
@@ -250,115 +453,15 @@ class GPTNeoXAttention(nn.Module):
return query, key, value, layer_past
def _attn(self, query, key, value, attention_mask=None, head_mask=None):
# q, k, v: [bs, num_attention_heads, seq_len, attn_head_size]
# compute causal mask from causal mask buffer
batch_size, num_attention_heads, query_length, attn_head_size = query.size()
key_length = key.size(-2)
# dynamically increase the causal mask with the key length, if needed.
if key_length > self.bias.shape[-1]:
self._init_bias(key_length, device=key.device)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
query = query.view(batch_size * num_attention_heads, query_length, attn_head_size)
key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
attn_scores = torch.zeros(
batch_size * num_attention_heads,
query_length,
key_length,
dtype=query.dtype,
device=key.device,
)
attn_scores = torch.baddbmm(
attn_scores,
query,
key.transpose(1, 2),
beta=1.0,
alpha=self.norm_factor,
)
attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length)
mask_value = torch.finfo(attn_scores.dtype).min
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
mask_value = torch.tensor(mask_value, dtype=attn_scores.dtype).to(attn_scores.device)
attn_scores = torch.where(causal_mask, attn_scores, mask_value)
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_scores = attn_scores + causal_mask
attn_weights = nn.functional.softmax(attn_scores, dim=-1)
attn_weights = attn_weights.to(value.dtype)
# Mask heads if we want to
if head_mask is not None:
attn_weights = attn_weights * head_mask
attn_weights = self.attention_dropout(attn_weights)
attn_output = torch.matmul(attn_weights, value)
return attn_output, attn_weights
class GPTNeoXFlashAttention2(GPTNeoXAttention):
def _fa_peft_dtype_check(self, value):
"""
GPTNeoX flash attention module. This module inherits from `GPTNeoXAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
PEFT can silently cast the dtype to float32 - this method returns the target dtype that
FA should cast back to (if necessary). For now, we cannot move this into the forward pass
itself because the last case depends on inspecting some of the module's own weights.
"""
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
def forward(
self,
hidden_states: torch.FloatTensor,
attention_mask: torch.FloatTensor,
position_ids: torch.LongTensor,
head_mask: Optional[torch.FloatTensor] = None,
layer_past: Optional[Cache] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
):
# Apply attention-specific projections and rope
query, key, value, present = self._attn_projections_and_rope(
hidden_states=hidden_states,
position_ids=position_ids,
layer_past=layer_past,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
query_length = query.shape[-2]
# GPT-neo-X casts query and key in fp32 to apply rotary embedding in full precision
target_dtype = value.dtype
if query.dtype != target_dtype:
query = query.to(target_dtype)
if key.dtype != target_dtype:
key = key.to(target_dtype)
# Permute to get the expected shape for Flash Attention
query = query.permute(0, 2, 1, 3)
key = key.permute(0, 2, 1, 3)
value = value.permute(0, 2, 1, 3)
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in float16 / bfloat16 just to be sure everything works as expected.
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
input_dtype = query.dtype
target_dtype = None
if self.config._attn_implementation == "flash_attention_2":
input_dtype = value.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
@@ -367,144 +470,29 @@ class GPTNeoXFlashAttention2(GPTNeoXAttention):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.query_key_value.weight.dtype
return target_dtype
# TODO Remove in deprecation cycle
class GPTNeoXFlashAttention2(GPTNeoXAttention):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
"The `GPTNeoXFlashAttention2` class is deprecated in favor of simply modifying the `config._attn_implementation`"
"attribute of the `GPTNeoXAttention` class! It will be removed in v4.48"
)
query = query.to(target_dtype)
key = key.to(target_dtype)
value = value.to(target_dtype)
attention_dropout = self.config.attention_dropout if self.training else 0.0
# Compute attention
attn_weights = _flash_attention_forward(
query,
key,
value,
attention_mask,
query_length,
dropout=attention_dropout,
softmax_scale=self.norm_factor,
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
)
# Reshape outputs
attn_output = attn_weights.reshape(
attn_weights.shape[0], attn_weights.shape[1], self.num_attention_heads * self.head_size
)
attn_output = self.dense(attn_output)
outputs = (attn_output, layer_past)
if output_attentions:
outputs += (attn_weights,)
return outputs
# TODO Remove in deprecation cycle
class GPTNeoXSdpaAttention(GPTNeoXAttention):
"""
GPTNeoX attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
`GPTNeoXAttention` as the weights of the module stays untouched. The only changes are on the forward pass
to adapt to the SDPA API.
"""
def __init__(self, config, layer_idx=None):
super().__init__(config, layer_idx=layer_idx)
# SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
# attn_mask, so we need to call `.contiguous()`. This was fixed in torch==2.2.0.
# Reference: https://github.com/pytorch/pytorch/issues/112577
self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0")
def forward(
self,
hidden_states: torch.FloatTensor,
attention_mask: torch.FloatTensor,
position_ids: torch.LongTensor,
head_mask: Optional[torch.FloatTensor] = None,
layer_past: Optional[Tuple[torch.Tensor]] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
):
if output_attentions or head_mask is not None:
logger.warning_once(
"`GPTNeoXSdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
"`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but "
"specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
"The `GPTNeoXSdpaAttention` class is deprecated in favor of simply modifying the `config._attn_implementation`"
"attribute of the `GPTNeoXAttention` class! It will be removed in v4.48"
)
return super().forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask,
layer_past=layer_past,
use_cache=use_cache,
output_attentions=output_attentions,
cache_position=cache_position,
)
bsz, q_len, _ = hidden_states.size()
# Apply attention-specific projections and rope
query, key, value, present = self._attn_projections_and_rope(
hidden_states=hidden_states,
position_ids=position_ids,
layer_past=layer_past,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
causal_mask = attention_mask
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key.shape[-2]]
# GPT-neo-X casts query and key in fp32 to apply rotary embedding in full precision
target_dtype = value.dtype
if query.dtype != target_dtype:
query = query.to(target_dtype)
if key.dtype != target_dtype:
key = key.to(target_dtype)
# Avoid torch==2.1.2 specific bug for the memory-efficient backend in SDPA
if self.require_contiguous_qkv and query.device.type == "cuda" and attention_mask is not None:
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
is_causal = True if causal_mask is None and q_len > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention(
query=query,
key=key,
value=value,
attn_mask=causal_mask,
dropout_p=self.attention_dropout.p if self.training else 0.0,
is_causal=is_causal,
)
# Reshape outputs
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
attn_output = self.dense(attn_output)
return attn_output, present, None
def attention_mask_func(attention_scores, ltor_mask):
attention_scores.masked_fill_(~ltor_mask, torch.finfo(attention_scores.dtype).min)
return attention_scores
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->GPTNeoX
@@ -675,6 +663,7 @@ GPT_NEOX_ATTENTION_CLASSES = {
"eager": GPTNeoXAttention,
"flash_attention_2": GPTNeoXFlashAttention2,
"sdpa": GPTNeoXSdpaAttention,
"flex_attention": GPTNeoXAttention,
}
@@ -919,7 +908,13 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
converted_head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
# Flex Attention converts it to a separate mask
if head_mask is not None:
converted_head_mask = ~converted_head_mask.bool() * torch.finfo(inputs_embeds.dtype).min
converted_head_mask = converted_head_mask.to(dtype=self.dtype, device=self.device)
head_mask = converted_head_mask
hidden_states = self.emb_dropout(inputs_embeds)
# create position embeddings to be shared across the decoder layers
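
A toy sketch of the head-mask conversion a few lines above: the usual multiplicative mask (1.0 = keep, 0.0 = prune) becomes an additive bias in which pruned heads receive the most negative representable value:

import torch

keep = torch.tensor([1.0, 0.0, 1.0, 1.0])               # prune head 1 out of 4
additive = ~keep.bool() * torch.finfo(torch.float32).min
# additive == tensor([0.0000e+00, -3.4028e+38, 0.0000e+00, 0.0000e+00])
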


@@ -206,6 +206,7 @@ from .import_utils import (
is_torch_compile_available,
is_torch_cuda_available,
is_torch_deterministic,
is_torch_flex_attn_available,
is_torch_fp16_available_on_device,
is_torch_fx_available,
is_torch_fx_proxy,


@@ -358,6 +358,17 @@ def is_torch_sdpa_available():
return version.parse(_torch_version) >= version.parse("2.1.1")
def is_torch_flex_attn_available():
if not is_torch_available():
return False
elif _torch_version == "N/A":
return False
# TODO check if some bugs cause push backs on the exact version
# NOTE: We require torch>=2.5.0 as it is the first release
return version.parse(_torch_version) >= version.parse("2.5.0")
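
Downstream code can guard the flex_attention import on this check, mirroring the modeling file above:

from transformers.utils import is_torch_flex_attn_available

if is_torch_flex_attn_available():
    # only reachable on torch>=2.5.0
    from torch.nn.attention.flex_attention import flex_attention
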
def is_torchvision_available():
return _torchvision_available
@@ -916,6 +927,7 @@ def is_flash_attn_2_available():
return False
@lru_cache()
def is_flash_attn_greater_or_equal_2_10():
if not _is_package_available("flash_attn"):
return False


@@ -459,6 +459,31 @@ class GPTNeoXLanguageGenerationTest(unittest.TestCase):
self.assertEqual(output_str, expected_output)
@slow
def test_lm_generate_flex_attn_gptneox(self):
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-410m-deduped")
for checkpointing in [True, False]:
model = GPTNeoXForCausalLM.from_pretrained(
"EleutherAI/pythia-410m-deduped", attn_implementation="flex_attention"
)
self.assertTrue(model.config._attn_implementation == "flex_attention")
if checkpointing:
model.gradient_checkpointing_enable()
else:
model.gradient_checkpointing_disable()
model.to(torch_device)
inputs = tokenizer("My favorite food is", return_tensors="pt").to(torch_device)
# The hub repo. is updated on 2023-04-04, resulting in poor outputs.
# See: https://github.com/huggingface/transformers/pull/24193
expected_output = "My favorite food is a good old-fashioned, old-fashioned, old-fashioned.\n\nI'm not sure"
output_ids = model.generate(**inputs, do_sample=False, max_new_tokens=20)
output_str = tokenizer.batch_decode(output_ids)[0]
self.assertEqual(output_str, expected_output)
def pythia_integration_test(self):
model_name_or_path = "EleutherAI/pythia-70m"
model = GPTNeoXForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float16).to(torch_device)