Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00

[GPTNeoX] Flex Attention + Refactor (#34896)

* gpt neox flex attention + refactor
* some formatting
* small fix on dropout
* add assertion on flex attn test
* flaky ci :(
* add head mask support
* style
* handle dtype, replace torch where
* fixup flex with output attns
* code review and several other fixes
* Update src/transformers/modeling_utils.py
  Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* style
* remove unnecessary comment
* remove incorrect comment
* make flex attn check more agnostic tor versions and centralized
* change peft input dtype check to value since q and k could be affected by other stuff like RoPE
* i forgor
* flaky
* code review and small fixes
* Update src/transformers/models/gpt_neox/modeling_gpt_neox.py
  Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

This commit is contained in:
parent accb7204f9
commit 46df859975
@@ -20,7 +20,10 @@ from typing import Optional, Tuple, TypedDict
import torch
import torch.nn.functional as F

from .utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal
from .utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal, logging


logger = logging.get_logger(__name__)


if is_flash_attn_2_available():
@@ -180,6 +183,47 @@ def prepare_fa2_from_position_ids(query, key, value, position_ids):
    return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length))


def fa_peft_integration_check(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    target_dtype: Optional[torch.dtype] = None,
):
"""
|
||||
PEFT usually casts the layer norms in float32 for training stability reasons
|
||||
therefore the input hidden states gets silently casted in float32. Hence, we need
|
||||
cast them back in float16 / bfloat16 just to be sure everything works as expected.
|
||||
This might slowdown training & inference so it is recommended to not cast the LayerNorms!
|
||||
|
||||
Args:
|
||||
query (`torch.Tensor`):
|
||||
Input query states to be passed to Flash Attention API
|
||||
key (`torch.Tensor`):
|
||||
Input key states to be passed to Flash Attention API
|
||||
value (`torch.Tensor`):
|
||||
Input value states to be passed to Flash Attention API
|
||||
target_dtype (`torch.dtype`, *optional*):
|
||||
The dtype to convert the attention tensors to. Conversion can be ignored by
|
||||
not providing the target dtype.
|
||||
"""
|
||||
    if target_dtype is None:
        return query, key, value

    input_dtype = value.dtype
    if input_dtype == torch.float32:
        logger.warning_once(
            f"The input hidden states seems to be silently casted in float32, this might be related to"
            f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
            f" {target_dtype}."
        )

        query = query.to(target_dtype)
        key = key.to(target_dtype)
        value = value.to(target_dtype)

    return query, key, value


flash_241 = is_flash_attn_greater_or_equal("2.4.1")
deterministic_g = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"

@@ -202,6 +246,7 @@ def _flash_attention_forward(
    cu_seq_lens_k: Optional[torch.LongTensor] = None,
    max_length_q: Optional[int] = None,
    max_length_k: Optional[int] = None,
    target_dtype: Optional[torch.dtype] = None,
):
    """
    Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
@@ -248,6 +293,11 @@ def _flash_attention_forward(
    if softcap is not None:
        flash_kwargs["softcap"] = softcap

    # PEFT possibly silently casts tensors to fp32, this potentially reconverts to correct dtype or is a no op
    query_states, key_states, value_states = fa_peft_integration_check(
        query_states, key_states, value_states, target_dtype
    )

    # Contains at least one padding token in the sequence
    if attention_mask is not None:
        batch_size = query_states.shape[0]
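The hunks above (apparently from `modeling_flash_attention_utils.py`) introduce `fa_peft_integration_check` and call it inside `_flash_attention_forward`. The standalone sketch below is not part of the diff; it only illustrates the intended behavior, with made-up tensor shapes, when PEFT has upcast the inputs to float32:

# --- illustrative sketch, not part of the diff ---
import torch

def peft_dtype_guard(query, key, value, target_dtype=None):
    # Mirrors the logic added above: no target dtype means pass-through.
    if target_dtype is None:
        return query, key, value
    # Only re-cast when the inputs were (silently) upcast to float32.
    if value.dtype == torch.float32:
        query, key, value = query.to(target_dtype), key.to(target_dtype), value.to(target_dtype)
    return query, key, value

q = torch.randn(1, 8, 16, 64, dtype=torch.float32)
k = torch.randn(1, 8, 16, 64, dtype=torch.float32)
v = torch.randn(1, 8, 16, 64, dtype=torch.float32)
q, k, v = peft_dtype_guard(q, k, v, target_dtype=torch.bfloat16)
print(q.dtype, k.dtype, v.dtype)  # torch.bfloat16 torch.bfloat16 torch.bfloat16
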
@@ -89,6 +89,7 @@ from .utils import (
    is_peft_available,
    is_remote_url,
    is_safetensors_available,
    is_torch_flex_attn_available,
    is_torch_greater_or_equal,
    is_torch_sdpa_available,
    is_torch_xla_available,
@@ -1342,6 +1343,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
    # SDPA support
    _supports_sdpa = False

    # Flex Attention support
    _supports_flex_attn = False

    # Has support for a `Cache` instance as `past_key_values`? Does it support a `StaticCache`?
    _supports_cache_class = False
    _supports_static_cache = False
@@ -1548,6 +1552,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                message += ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)'
            if cls._supports_sdpa:
                message += ', `"attn_implementation=sdpa"` (implementation using torch.nn.functional.scaled_dot_product_attention)'
            if cls._supports_flex_attn:
                message += (
                    ', `"attn_implementation=flex_attention"` (implementation using torch\'s flex_attention)'
                )
            raise ValueError(message + ".")

        # If a config is passed with a preset attn_implementation, we skip the automatic dispatch and use the user-provided config, with hard checks that the requested attention implementation is available.
@@ -1582,6 +1590,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                hard_check_only=False,
                check_device_map=check_device_map,
            )
        elif requested_attn_implementation == "flex_attention":
            config = cls._check_and_enable_flex_attn(config, hard_check_only=True)
        elif requested_attn_implementation in [None, "sdpa"] and not is_torch_xla_available():
            # use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif.
            config = cls._check_and_enable_sdpa(
@@ -1778,7 +1788,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
        """
        Checks the availability of SDPA for a given model.

        If all checks pass and `hard_check_only` is False, the method will set the config attribute `_attn_implementation` to "flash_attention_2" so that the model can initialize the correct attention module.
        If all checks pass and `hard_check_only` is False, the method will set the config attribute `_attn_implementation` to "sdpa" so that the model can initialize the correct attention module.
        """
        if hard_check_only:
            if not cls._supports_sdpa:
@@ -1803,6 +1813,35 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
            config._attn_implementation = "sdpa"
        return config

    @classmethod
    def _check_and_enable_flex_attn(cls, config, hard_check_only: bool = False) -> PretrainedConfig:
        """
        Checks the availability of Flex Attention for a given model.

        If all checks pass and `hard_check_only` is False, the method will set the config attribute `_attn_implementation` to "flex_attention" so that the model can initialize the correct attention module.
        """
        if hard_check_only:
            if not cls._supports_flex_attn:
                raise ValueError(
                    f"{cls.__name__} does not support an attention implementation through torch's flex_attention."
                    " Please request the support for this architecture: https://github.com/huggingface/transformers/issues/34809."
                    " If you believe this error is a bug, please open an issue in Transformers GitHub repository"
                    ' and load your model with the argument `attn_implementation="eager"` meanwhile.'
                    ' Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`'
                )
            if not is_torch_flex_attn_available():
                raise ImportError(
                    "PyTorch Flex Attention requirements in Transformers are not met. Please install torch>=2.5.0."
                )

        if not is_torch_flex_attn_available() or not cls._supports_flex_attn:
            return config

        if not hard_check_only:
            config._attn_implementation = "flex_attention"

        return config

    def enable_input_require_grads(self):
        """
        Enables the gradients for the input embeddings. This is useful for fine-tuning adapter weights while keeping
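The `modeling_utils.py` hunks above add the `_supports_flex_attn` flag and a `_check_and_enable_flex_attn` gate that mirrors the existing SDPA path. A minimal usage sketch, assuming torch>=2.5.0 and a flex-capable GPTNeoX checkpoint such as `EleutherAI/pythia-70m` (any model with `_supports_flex_attn = True` should behave the same):

# --- illustrative sketch, not part of the diff ---
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Requesting flex attention explicitly; on torch<2.5.0 this is expected to raise
# the ImportError added in _check_and_enable_flex_attn above.
model = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/pythia-70m",
    attn_implementation="flex_attention",
    torch_dtype=torch.bfloat16,
)
assert model.config._attn_implementation == "flex_attention"

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m")
inputs = tokenizer("Flex attention says", return_tensors="pt")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=10)[0]))
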
@@ -18,7 +18,6 @@ from typing import Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from packaging import version
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

@@ -42,9 +41,9 @@ from ...modeling_outputs import (
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import PreTrainedModel
from ...utils import (
    get_torch_version,
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
    is_torch_flex_attn_available,
    logging,
)
from .configuration_gpt_neox import GPTNeoXConfig
@@ -53,6 +52,9 @@ from .configuration_gpt_neox import GPTNeoXConfig
if is_flash_attn_2_available():
    from ...modeling_flash_attention_utils import _flash_attention_forward

if is_torch_flex_attn_available():
    from torch.nn.attention.flex_attention import flex_attention

logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "trl-internal-testing/tiny-random-GPTNeoXForCausalLM"
@@ -76,6 +78,7 @@ class GPTNeoXPreTrainedModel(PreTrainedModel):
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_sdpa = True
    _supports_flex_attn = True

    def _init_weights(self, module):
        """Initialize the weights"""
@@ -92,6 +95,169 @@ class GPTNeoXPreTrainedModel(PreTrainedModel):
            module.weight.data.fill_(1.0)


def eager_attention_forward(
    query, key, value, attention_mask, head_mask, norm_factor, attention_dropout, training, **_kwargs
):
    # q, k, v: [bs, num_attention_heads, seq_len, attn_head_size]
    batch_size, num_attention_heads, query_length, attn_head_size = query.size()
    key_length = key.size(-2)

    query = query.view(batch_size * num_attention_heads, query_length, attn_head_size)
    key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
    attn_scores = torch.zeros(
        batch_size * num_attention_heads,
        query_length,
        key_length,
        dtype=query.dtype,
        device=key.device,
    )
    attn_scores = torch.baddbmm(
        attn_scores,
        query,
        key.transpose(1, 2),
        beta=1.0,
        alpha=norm_factor,
    )
    attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length)

    if attention_mask is not None:  # no matter the length, we just slice it
        causal_mask = attention_mask[:, :, :, : key.shape[-2]]
        attn_scores = attn_scores + causal_mask

    attn_weights = nn.functional.softmax(attn_scores, dim=-1)
    attn_weights = attn_weights.to(value.dtype)

    # Mask heads if we want to
    if head_mask is not None:
        attn_weights = attn_weights * head_mask

    attn_weights = nn.functional.dropout(attn_weights, p=attention_dropout, training=training)
    attn_output = torch.matmul(attn_weights, value)

    # Reshape outputs
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights


def flash_attention_forward(
    query,
    key,
    value,
    attention_mask,
    norm_factor,
    attention_dropout,
    training,
    target_dtype=None,
    **_kwargs,
):
    query_length = query.shape[-2]

    # GPT-neo-X casts query and key in fp32 to apply rotary embedding in full precision
    query = query.to(value.dtype)
    key = key.to(value.dtype)

    # Permute to get the expected shape for Flash Attention
    query = query.transpose(1, 2)
    key = key.transpose(1, 2)
    value = value.transpose(1, 2)

    attention_dropout = attention_dropout if training else 0.0
    flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    # Compute attention
    attn_output = _flash_attention_forward(
        query,
        key,
        value,
        attention_mask,
        query_length,
        dropout=attention_dropout,
        softmax_scale=norm_factor,
        is_causal=True,
        use_top_left_mask=flash_attn_uses_top_left_mask,
        target_dtype=target_dtype,
    )

    return attn_output, None


def sdpa_attention_forward(query, key, value, attention_mask, attention_dropout, training, **_kwargs):
    q_len = query.shape[-2]

    causal_mask = attention_mask
    if attention_mask is not None:
        causal_mask = causal_mask[:, :, :, : key.shape[-2]]

    # GPT-neo-X casts query and key in fp32 to apply rotary embedding in full precision
    query = query.to(value.dtype)
    key = key.to(value.dtype)

    # Avoid torch==2.1.2 specific bug for the memory-efficient backend in SDPA
    query = query.contiguous()
    key = key.contiguous()
    value = value.contiguous()

    # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
    # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
    is_causal = True if causal_mask is None and q_len > 1 else False

    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query=query,
        key=key,
        value=value,
        attn_mask=causal_mask,
        dropout_p=attention_dropout if training else 0.0,
        is_causal=is_causal,
    )

    # Reshape outputs
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, None


def flex_attention_forward(query, key, value, attention_mask, head_mask, norm_factor, **_kwargs):
    causal_mask = attention_mask
    if causal_mask is not None:
        causal_mask = causal_mask[:, :, :, : key.shape[-2]]

    def causal_mod(score, b, h, q_idx, kv_idx):
        if causal_mask is not None:
            score += causal_mask[b][0][q_idx][kv_idx]
        if head_mask is not None:
            score += head_mask[b][h][0][0]
        return score

    attn_output, attn_weights = flex_attention(
        query,
        key,
        value,
        score_mod=causal_mod,
        enable_gqa=True,
        scale=norm_factor,
        # Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
        # For simplification, we thus always return it as no additional computations are introduced.
        return_lse=True,
    )

    # lse is returned in float32
    attn_weights = attn_weights.to(value.dtype)

    # Reshape outputs
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights


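`flex_attention_forward` above folds the additive causal mask and the converted head mask into FlexAttention's `score_mod` callback. The following self-contained sketch (not part of the diff; illustrative shapes, torch>=2.5 assumed) shows the same idea in isolation:

# --- illustrative sketch, not part of the diff ---
import torch
from torch.nn.attention.flex_attention import flex_attention

B, H, Q, KV, D = 1, 4, 8, 8, 16
query, key, value = (torch.randn(B, H, Q, D) for _ in range(3))

# Additive biases, laid out the way the model-level code builds them:
# 0 where attention is allowed, a very negative number where it is masked.
min_val = torch.finfo(torch.float32).min
causal_mask = torch.triu(torch.full((Q, KV), min_val), diagonal=1)[None, None, :, :]
head_bias = torch.zeros(B, H, 1, 1)  # e.g. a converted head_mask, 0 = keep the head

def score_mod(score, b, h, q_idx, kv_idx):
    # Each raw score receives the causal bias of its (q, kv) position plus its head's bias.
    score = score + causal_mask[b][0][q_idx][kv_idx]
    score = score + head_bias[b][h][0][0]
    return score

out = flex_attention(query, key, value, score_mod=score_mod)
print(out.shape)  # torch.Size([1, 4, 8, 16])
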
GPTNEOX_ATTENTION_FUNCTION = {
    "eager": eager_attention_forward,
    "flash_attention_2": flash_attention_forward,
    "sdpa": sdpa_attention_forward,
    "flex_attention": flex_attention_forward,
}


class GPTNeoXAttention(nn.Module):
    def __init__(self, config, layer_idx=None):
        super().__init__()
@@ -147,20 +313,57 @@ class GPTNeoXAttention(nn.Module):
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
    ):
        bsz, seq_len, _ = hidden_states.shape

        # Apply attention-specific projections and rope
        query, key, value, present = self._attn_projections_and_rope(
            hidden_states=hidden_states,
            position_ids=position_ids,
            layer_past=layer_past,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
        )

        # Compute attention
        attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
        # Checking for fallbacks in case an unsupported feature is requested
        attention_type = self.config._attn_implementation
        if (output_attentions or head_mask is not None) and self.config._attn_implementation in [
            "sdpa",
            "flash_attention_2",
        ]:
            logger.warning_once(
                f"Setting `attention_type` to `eager` because `{attention_type}` does not support"
                f" `output_attentions=True` or `head_mask`."
            )
            attention_type = "eager"

        # Reshape outputs
        attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size)
        elif (
            self.training
            and self.config.attention_dropout > 0
            and self.config._attn_implementation == "flex_attention"
        ):
            logger.warning_once(
                f"Setting `attention_type` to `eager` because `dropout` is not supported in `{attention_type}`."
            )
            attention_type = "eager"

        # Compute attention
        attn_output, attn_weights = GPTNEOX_ATTENTION_FUNCTION[attention_type](
            query,
            key,
            value,
            attention_mask=attention_mask,
            head_mask=head_mask,
            norm_factor=self.norm_factor,
            attention_dropout=self.config.attention_dropout,
            training=self.training,
            # Flash Attention 2 specific PEFT check
            target_dtype=self._fa_peft_dtype_check(value),
        )

        # Reshape outputs and final projection
        attn_output = attn_output.contiguous()
        attn_output = attn_output.view(bsz, seq_len, -1)
        attn_output = self.dense(attn_output)

        outputs = (attn_output, present)
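In the rewritten `forward` above, every backend shares one calling convention, so the method simply indexes `GPTNEOX_ATTENTION_FUNCTION` with a string and rewrites that string to `"eager"` when an unsupported feature (`output_attentions`, `head_mask`, or dropout under flex attention) is requested. A toy illustration of this dispatch-plus-fallback pattern (hypothetical names, not part of the diff):

# --- illustrative sketch, not part of the diff ---
from typing import Callable, Dict

def eager_impl(q, k, v, **kwargs):
    return "eager result (weights available)"

def sdpa_impl(q, k, v, **kwargs):
    return "sdpa result (no weights)"

# One registry, one call site: unsupported features simply rewrite the key.
ATTENTION_REGISTRY: Dict[str, Callable] = {"eager": eager_impl, "sdpa": sdpa_impl}

def run_attention(attn_implementation: str, q, k, v, output_attentions: bool = False):
    if output_attentions and attn_implementation != "eager":
        attn_implementation = "eager"  # fallback, mirroring the warning in the forward pass
    return ATTENTION_REGISTRY[attn_implementation](q, k, v)

print(run_attention("sdpa", None, None, None))
print(run_attention("sdpa", None, None, None, output_attentions=True))
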
@@ -250,262 +453,47 @@ class GPTNeoXAttention(nn.Module):

        return query, key, value, layer_past

    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
        # q, k, v: [bs, num_attention_heads, seq_len, attn_head_size]
        # compute causal mask from causal mask buffer
        batch_size, num_attention_heads, query_length, attn_head_size = query.size()
        key_length = key.size(-2)

        # dynamically increase the causal mask with the key length, if needed.
        if key_length > self.bias.shape[-1]:
            self._init_bias(key_length, device=key.device)
        causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]

        query = query.view(batch_size * num_attention_heads, query_length, attn_head_size)
        key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
        attn_scores = torch.zeros(
            batch_size * num_attention_heads,
            query_length,
            key_length,
            dtype=query.dtype,
            device=key.device,
        )
        attn_scores = torch.baddbmm(
            attn_scores,
            query,
            key.transpose(1, 2),
            beta=1.0,
            alpha=self.norm_factor,
        )
        attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length)

        mask_value = torch.finfo(attn_scores.dtype).min
        # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
        # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
        mask_value = torch.tensor(mask_value, dtype=attn_scores.dtype).to(attn_scores.device)
        attn_scores = torch.where(causal_mask, attn_scores, mask_value)

        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key.shape[-2]]
            attn_scores = attn_scores + causal_mask

        attn_weights = nn.functional.softmax(attn_scores, dim=-1)
        attn_weights = attn_weights.to(value.dtype)

        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_weights = self.attention_dropout(attn_weights)

        attn_output = torch.matmul(attn_weights, value)
        return attn_output, attn_weights
    def _fa_peft_dtype_check(self, value):
        """
        PEFT can silently cast the inputs to float32 - this method returns the target dtype that
        Flash Attention should convert them back to (if necessary). For now, we cannot move this to
        the forward pass itself because the last case depends on inspecting this module's own weights.
        """
        target_dtype = None
        if self.config._attn_implementation == "flash_attention_2":
            input_dtype = value.dtype
            if input_dtype == torch.float32:
                if torch.is_autocast_enabled():
                    target_dtype = torch.get_autocast_gpu_dtype()
                # Handle the case where the model is quantized
                elif hasattr(self.config, "_pre_quantization_dtype"):
                    target_dtype = self.config._pre_quantization_dtype
                else:
                    target_dtype = self.query_key_value.weight.dtype
        return target_dtype


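`_fa_peft_dtype_check` above resolves the fallback dtype in a fixed priority order: the autocast dtype, then a quantized checkpoint's `_pre_quantization_dtype`, then the dtype of the QKV projection weight. A hedged sketch of that resolution order with stand-in `config` and weight objects (not part of the diff):

# --- illustrative sketch, not part of the diff ---
import torch

def resolve_target_dtype(value: torch.Tensor, config, qkv_weight: torch.Tensor):
    # Only relevant when the inputs were upcast to float32 (e.g. by PEFT).
    if value.dtype != torch.float32:
        return None
    if torch.is_autocast_enabled():
        return torch.get_autocast_gpu_dtype()
    if hasattr(config, "_pre_quantization_dtype"):
        return config._pre_quantization_dtype
    return qkv_weight.dtype

class DummyConfig:  # stand-in for a config without quantization metadata
    pass

weight = torch.nn.Linear(4, 4, dtype=torch.bfloat16).weight
print(resolve_target_dtype(torch.ones(2, 2), DummyConfig(), weight))  # torch.bfloat16
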
# TODO Remove in deprecation cycle
class GPTNeoXFlashAttention2(GPTNeoXAttention):
    """
    GPTNeoX flash attention module. This module inherits from `GPTNeoXAttention` as the weights of the module stays
    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.
    """

    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: torch.FloatTensor,
        position_ids: torch.LongTensor,
        head_mask: Optional[torch.FloatTensor] = None,
        layer_past: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
    ):
        # Apply attention-specific projections and rope
        query, key, value, present = self._attn_projections_and_rope(
            hidden_states=hidden_states,
            position_ids=position_ids,
            layer_past=layer_past,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
        logger.warning_once(
            "The `GPTNeoXFlashAttention2` class is deprecated in favor of simply modifying the `config._attn_implementation`"
            "attribute of the `GPTNeoXAttention` class! It will be removed in v4.48"
        )

        query_length = query.shape[-2]

        # GPT-neo-X casts query and key in fp32 to apply rotary embedding in full precision
        target_dtype = value.dtype
        if query.dtype != target_dtype:
            query = query.to(target_dtype)
        if key.dtype != target_dtype:
            key = key.to(target_dtype)

        # Permute to get the expected shape for Flash Attention
        query = query.permute(0, 2, 1, 3)
        key = key.permute(0, 2, 1, 3)
        value = value.permute(0, 2, 1, 3)

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
        # therefore the input hidden states gets silently casted in float32. Hence, we need
        # cast them back in float16 / bfloat16 just to be sure everything works as expected.
        # This might slowdown training & inference so it is recommended to not cast the LayerNorms
        input_dtype = query.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.query_key_value.weight.dtype

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query = query.to(target_dtype)
            key = key.to(target_dtype)
            value = value.to(target_dtype)

        attention_dropout = self.config.attention_dropout if self.training else 0.0

        # Compute attention
        attn_weights = _flash_attention_forward(
            query,
            key,
            value,
            attention_mask,
            query_length,
            dropout=attention_dropout,
            softmax_scale=self.norm_factor,
            is_causal=self.is_causal,
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
        )

        # Reshape outputs
        attn_output = attn_weights.reshape(
            attn_weights.shape[0], attn_weights.shape[1], self.num_attention_heads * self.head_size
        )
        attn_output = self.dense(attn_output)

        outputs = (attn_output, layer_past)
        if output_attentions:
            outputs += (attn_weights,)

        return outputs


# TODO Remove in deprecation cycle
class GPTNeoXSdpaAttention(GPTNeoXAttention):
    """
    GPTNeoX attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
    `GPTNeoXAttention` as the weights of the module stays untouched. The only changes are on the forward pass
    to adapt to the SDPA API.
    """

    def __init__(self, config, layer_idx=None):
        super().__init__(config, layer_idx=layer_idx)

        # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
        # attn_mask, so we need to call `.contiguous()`. This was fixed in torch==2.2.0.
        # Reference: https://github.com/pytorch/pytorch/issues/112577
        self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0")

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: torch.FloatTensor,
        position_ids: torch.LongTensor,
        head_mask: Optional[torch.FloatTensor] = None,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
    ):
        if output_attentions or head_mask is not None:
            logger.warning_once(
                "`GPTNeoXSdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
                "`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but "
                "specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
                'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                head_mask=head_mask,
                layer_past=layer_past,
                use_cache=use_cache,
                output_attentions=output_attentions,
                cache_position=cache_position,
            )

        bsz, q_len, _ = hidden_states.size()

        # Apply attention-specific projections and rope
        query, key, value, present = self._attn_projections_and_rope(
            hidden_states=hidden_states,
            position_ids=position_ids,
            layer_past=layer_past,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
        logger.warning_once(
            "The `GPTNeoXSdpaAttention` class is deprecated in favor of simply modifying the `config._attn_implementation`"
            "attribute of the `GPTNeoXAttention` class! It will be removed in v4.48"
        )

        causal_mask = attention_mask
        if attention_mask is not None:
            causal_mask = causal_mask[:, :, :, : key.shape[-2]]

        # GPT-neo-X casts query and key in fp32 to apply rotary embedding in full precision
        target_dtype = value.dtype
        if query.dtype != target_dtype:
            query = query.to(target_dtype)
        if key.dtype != target_dtype:
            key = key.to(target_dtype)

        # Avoid torch==2.1.2 specific bug for the memory-efficient backend in SDPA
        if self.require_contiguous_qkv and query.device.type == "cuda" and attention_mask is not None:
            query = query.contiguous()
            key = key.contiguous()
            value = value.contiguous()

        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
        is_causal = True if causal_mask is None and q_len > 1 else False

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query=query,
            key=key,
            value=value,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout.p if self.training else 0.0,
            is_causal=is_causal,
        )

        # Reshape outputs
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)

        attn_output = self.dense(attn_output)

        return attn_output, present, None


def attention_mask_func(attention_scores, ltor_mask):
    attention_scores.masked_fill_(~ltor_mask, torch.finfo(attention_scores.dtype).min)
    return attention_scores


# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->GPTNeoX
class GPTNeoXRotaryEmbedding(nn.Module):
@@ -675,6 +663,7 @@ GPT_NEOX_ATTENTION_CLASSES = {
    "eager": GPTNeoXAttention,
    "flash_attention_2": GPTNeoXFlashAttention2,
    "sdpa": GPTNeoXSdpaAttention,
    "flex_attention": GPTNeoXAttention,
}


@@ -919,7 +908,13 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
        converted_head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
        # Flex Attention converts it to a separate mask
        if head_mask is not None:
            converted_head_mask = ~converted_head_mask.bool() * torch.finfo(inputs_embeds.dtype).min
            converted_head_mask = converted_head_mask.to(dtype=self.dtype, device=self.device)
        head_mask = converted_head_mask

        hidden_states = self.emb_dropout(inputs_embeds)

        # create position embeddings to be shared across the decoder layers
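The hunk above turns the multiplicative mask returned by `get_head_mask` (1.0 = keep a head, 0.0 = prune it) into an additive bias so that the FlexAttention `score_mod` can simply add it per head. A small sketch of that conversion with illustrative sizes (not part of the diff):

# --- illustrative sketch, not part of the diff ---
import torch

num_hidden_layers, batch, num_heads = 2, 1, 4
dtype = torch.bfloat16

# get_head_mask() broadcasts a [num_heads] mask to a per-layer, per-head shape.
head_mask = torch.tensor([1.0, 0.0, 1.0, 1.0])
converted = head_mask.view(1, 1, num_heads, 1, 1).expand(num_hidden_layers, batch, num_heads, 1, 1)

# Same trick as above: kept heads get a 0 bias, pruned heads a very large negative one.
converted = ~converted.bool() * torch.finfo(dtype).min
converted = converted.to(dtype)
print(converted[0, 0, :, 0, 0])  # bias is 0 for kept heads, ~ -3.4e38 for the pruned head
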
@@ -206,6 +206,7 @@ from .import_utils import (
    is_torch_compile_available,
    is_torch_cuda_available,
    is_torch_deterministic,
    is_torch_flex_attn_available,
    is_torch_fp16_available_on_device,
    is_torch_fx_available,
    is_torch_fx_proxy,
@@ -358,6 +358,17 @@ def is_torch_sdpa_available():
    return version.parse(_torch_version) >= version.parse("2.1.1")


def is_torch_flex_attn_available():
    if not is_torch_available():
        return False
    elif _torch_version == "N/A":
        return False

    # TODO check if some bugs cause push backs on the exact version
    # NOTE: We require torch>=2.5.0 as it is the first release
    return version.parse(_torch_version) >= version.parse("2.5.0")


def is_torchvision_available():
    return _torchvision_available

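`is_torch_flex_attn_available` is a plain torch>=2.5.0 gate, and the earlier `utils/__init__.py` hunk re-exports it. A hedged sketch of how calling code might use it to pick an implementation (not part of the diff):

# --- illustrative sketch, not part of the diff ---
from transformers.utils import is_torch_flex_attn_available

# Prefer flex attention when the runtime supports it, otherwise fall back to SDPA.
attn_implementation = "flex_attention" if is_torch_flex_attn_available() else "sdpa"
print(f"Using attn_implementation={attn_implementation!r}")
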
@@ -916,6 +927,7 @@ def is_flash_attn_2_available():
    return False


@lru_cache()
def is_flash_attn_greater_or_equal_2_10():
    if not _is_package_available("flash_attn"):
        return False
@@ -459,6 +459,31 @@ class GPTNeoXLanguageGenerationTest(unittest.TestCase):

        self.assertEqual(output_str, expected_output)

    @slow
    def test_lm_generate_flex_attn_gptneox(self):
        tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-410m-deduped")
        for checkpointing in [True, False]:
            model = GPTNeoXForCausalLM.from_pretrained(
                "EleutherAI/pythia-410m-deduped", attn_implementation="flex_attention"
            )
            self.assertTrue(model.config._attn_implementation == "flex_attention")

            if checkpointing:
                model.gradient_checkpointing_enable()
            else:
                model.gradient_checkpointing_disable()
            model.to(torch_device)

            inputs = tokenizer("My favorite food is", return_tensors="pt").to(torch_device)
            # The hub repo. is updated on 2023-04-04, resulting in poor outputs.
            # See: https://github.com/huggingface/transformers/pull/24193
            expected_output = "My favorite food is a good old-fashioned, old-fashioned, old-fashioned.\n\nI'm not sure"

            output_ids = model.generate(**inputs, do_sample=False, max_new_tokens=20)
            output_str = tokenizer.batch_decode(output_ids)[0]

            self.assertEqual(output_str, expected_output)

    def pythia_integration_test(self):
        model_name_or_path = "EleutherAI/pythia-70m"
        model = GPTNeoXForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float16).to(torch_device)