[Phi] Add support for sdpa (#29108)
parent 7688d8df84
commit b8b16475d4
@@ -172,6 +172,7 @@ For now, Transformers supports SDPA inference and training for the following architectures:
 * [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel)
 * [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
 * [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel)
+* [Phi](https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel)
 * [Idefics](https://huggingface.co/docs/transformers/model_doc/idefics#transformers.IdeficsModel)
 * [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel)
 * [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel)
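With Phi added to the list above, the SDPA path can be requested explicitly when loading the model. A minimal usage sketch, assuming a Phi checkpoint on the Hub (`microsoft/phi-1_5` is only an example and not part of this diff):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "microsoft/phi-1_5"  # example checkpoint; any Phi model works the same way

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    attn_implementation="sdpa",  # dispatch to the PhiSdpaAttention class added below
)

inputs = tokenizer("def print_prime(n):", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```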
@@ -22,12 +22,16 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
+from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
-from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
+from ...modeling_attn_mask_utils import (
+    _prepare_4d_causal_attention_mask,
+    _prepare_4d_causal_attention_mask_for_sdpa,
+)
 from ...modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -39,6 +43,7 @@ from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
+    get_torch_version,
     is_flash_attn_2_available,
     is_flash_attn_greater_or_equal_2_10,
     logging,
@@ -617,9 +622,121 @@ class PhiFlashAttention2(PhiAttention):
         )


+class PhiSdpaAttention(PhiAttention):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0")
+
+    """
+    SDPA attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+    `PhiAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
+    SDPA API.
+    """
+
+    # Adapted from PhiAttention.forward
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        if output_attentions:
+            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+            logger.warning_once(
+                "PhiModel is using PhiSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not "
+                "support `output_attentions=True`. Falling back to the manual attention implementation, but specifying "
+                "the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can "
+                'be removed using the argument `attn_implementation="eager"` when loading the model.'
+            )
+            return super().forward(
+                hidden_states=hidden_states,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+            )
+
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        if self.qk_layernorm:
+            query_states = self.q_layernorm(query_states)
+            key_states = self.k_layernorm(key_states)
+
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            if self.layer_idx is None:
+                raise ValueError(
+                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+                    "with a layer index."
+                )
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+
+        # Partial rotary embedding
+        query_rot, query_pass = (
+            query_states[..., : self.rotary_emb.dim],
+            query_states[..., self.rotary_emb.dim :],
+        )
+        key_rot, key_pass = (
+            key_states[..., : self.rotary_emb.dim],
+            key_states[..., self.rotary_emb.dim :],
+        )
+        # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
+        query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
+
+        # [batch_size, seq_length, num_heads, head_dim]
+        query_states = torch.cat((query_rot, query_pass), dim=-1)
+        key_states = torch.cat((key_rot, key_pass), dim=-1)
+
+        if past_key_value is not None:
+            cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+        # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
+        # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
+        # Reference: https://github.com/pytorch/pytorch/issues/112577
+        if self.require_contiguous_qkv and query_states.device.type == "cuda" and attention_mask is not None:
+            query_states = query_states.contiguous()
+            key_states = key_states.contiguous()
+            value_states = value_states.contiguous()
+
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=attention_mask,
+            dropout_p=self.attention_dropout if self.training else 0.0,
+            is_causal=self.is_causal and attention_mask is None and q_len > 1,
+        )
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+        attn_output = self.dense(attn_output)
+
+        return attn_output, None, past_key_value
+
+
 PHI_ATTENTION_CLASSES = {
     "eager": PhiAttention,
     "flash_attention_2": PhiFlashAttention2,
+    "sdpa": PhiSdpaAttention,
 }
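The new class keeps `PhiAttention`'s projections, rotary embedding, and cache handling, and only swaps the manual matmul/softmax/matmul for the fused `torch.nn.functional.scaled_dot_product_attention` kernel. A self-contained sketch on toy tensors (not the module's real weights or the diff itself) showing that the fused call matches the eager computation it replaces:

```python
import math
import torch
import torch.nn.functional as F

# Toy tensors in the (batch, num_heads, seq_len, head_dim) layout used by the attention classes.
q = torch.randn(2, 4, 8, 16)
k = torch.randn(2, 4, 8, 16)
v = torch.randn(2, 4, 8, 16)

# Eager path: explicit softmax(Q K^T / sqrt(d)) V with an additive causal mask.
causal_mask = torch.full((8, 8), float("-inf")).triu(1)
scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1)) + causal_mask
eager_out = torch.softmax(scores, dim=-1) @ v

# Fused path: the call made by PhiSdpaAttention (is_causal applies the same mask internally).
sdpa_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(torch.allclose(eager_out, sdpa_out, atol=1e-5))  # True
```

`PHI_ATTENTION_CLASSES` is then indexed with `config._attn_implementation`, so loading with `attn_implementation="sdpa"` selects `PhiSdpaAttention` for every decoder layer.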
@@ -714,6 +831,7 @@ class PhiPreTrainedModel(PreTrainedModel):
     _no_split_modules = ["PhiDecoderLayer"]
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
+    _supports_sdpa = True
     _supports_cache_class = True

     def _init_weights(self, module):
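`_supports_sdpa = True` is the flag that lets `PreTrainedModel` accept "sdpa" for Phi and, on a recent enough PyTorch, pick it as the default implementation. A quick way to check what was resolved after loading (the checkpoint name is again only an example):

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("microsoft/phi-1_5")
# Typically prints "sdpa" on torch >= 2.1.1; pass attn_implementation="eager"
# or "flash_attention_2" at load time to override the choice.
print(model.config._attn_implementation)
```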
@@ -821,7 +939,9 @@ class PhiModel(PhiPreTrainedModel):
             [PhiDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
         self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

         self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+        self._use_sdpa = config._attn_implementation == "sdpa"
+
         self.gradient_checkpointing = False
         # Initialize weights and apply final processing
@@ -895,6 +1015,13 @@ class PhiModel(PhiPreTrainedModel):
         if self._use_flash_attention_2:
             # 2d mask is passed through the layers
             attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+        elif self._use_sdpa and not output_attentions:
+            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                attention_mask,
+                (batch_size, seq_length),
+                inputs_embeds,
+                past_key_values_length,
+            )
         else:
             # 4d mask is passed through the layers
             attention_mask = _prepare_4d_causal_attention_mask(
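In the SDPA branch the 2D padding mask is expanded into the 4D additive mask that `scaled_dot_product_attention` accepts via `attn_mask` (the SDPA-specific helper can also return `None`, letting the kernel's own `is_causal` fast path handle the purely causal case). A conceptual sketch of that expansion, not the actual `_prepare_4d_causal_attention_mask_for_sdpa` implementation:

```python
import torch

def sketch_4d_causal_mask(padding_mask: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
    """Expand a (batch, key_len) mask of 1s/0s into a (batch, 1, query_len, key_len)
    additive float mask combining causality and padding, roughly what the helper builds."""
    bsz, seq_len = padding_mask.shape
    # Boolean "may attend" matrix: causal (no future keys) and key position not padded.
    causal = torch.ones(seq_len, seq_len, dtype=torch.bool).tril()
    allowed = causal[None, None, :, :] & padding_mask[:, None, None, :].bool()
    # Additive form: 0 where attention is allowed, a very large negative value elsewhere.
    zeros = torch.zeros(bsz, 1, seq_len, seq_len, dtype=dtype)
    return zeros.masked_fill(~allowed, torch.finfo(dtype).min)

print(sketch_4d_causal_mask(torch.tensor([[1, 1, 1, 0]]))[0, 0])
```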