diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md
index 745a0f98a59..36452aabd4d 100644
--- a/docs/source/en/perf_infer_gpu_one.md
+++ b/docs/source/en/perf_infer_gpu_one.md
@@ -172,6 +172,7 @@ For now, Transformers supports SDPA inference and training for the following arc
 * [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel)
 * [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
 * [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel)
+* [Phi](https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel)
 * [Idefics](https://huggingface.co/docs/transformers/model_doc/idefics#transformers.IdeficsModel)
 * [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel)
 * [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel)
diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py
index 799fe02c8f4..b4d261d07f4 100644
--- a/src/transformers/models/phi/modeling_phi.py
+++ b/src/transformers/models/phi/modeling_phi.py
@@ -22,12 +22,16 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
+from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
-from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
+from ...modeling_attn_mask_utils import (
+    _prepare_4d_causal_attention_mask,
+    _prepare_4d_causal_attention_mask_for_sdpa,
+)
 from ...modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -39,6 +43,7 @@ from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
+    get_torch_version,
     is_flash_attn_2_available,
     is_flash_attn_greater_or_equal_2_10,
     logging,
@@ -617,9 +622,121 @@ class PhiFlashAttention2(PhiAttention):
         )
 
 
+class PhiSdpaAttention(PhiAttention):
+    """
+    SDPA attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+    `PhiAttention`, as the weights of the module stay untouched. The only changes are on the forward pass to adapt
+    to the SDPA API.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0")
+
+    # Adapted from PhiAttention.forward
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        if output_attentions:
+            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+            logger.warning_once(
+                "PhiModel is using PhiSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not "
+                "support `output_attentions=True`. Falling back to the manual attention implementation, but specifying "
+                "the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can "
+                'be removed using the argument `attn_implementation="eager"` when loading the model.'
+            )
+            return super().forward(
+                hidden_states=hidden_states,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+            )
+
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        if self.qk_layernorm:
+            query_states = self.q_layernorm(query_states)
+            key_states = self.k_layernorm(key_states)
+
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            if self.layer_idx is None:
+                raise ValueError(
+                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+                    "with a layer index."
+                )
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+
+        # Partial rotary embedding
+        query_rot, query_pass = (
+            query_states[..., : self.rotary_emb.dim],
+            query_states[..., self.rotary_emb.dim :],
+        )
+        key_rot, key_pass = (
+            key_states[..., : self.rotary_emb.dim],
+            key_states[..., self.rotary_emb.dim :],
+        )
+        # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
+        query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
+
+        # [batch_size, seq_length, num_heads, head_dim]
+        query_states = torch.cat((query_rot, query_pass), dim=-1)
+        key_states = torch.cat((key_rot, key_pass), dim=-1)
+
+        if past_key_value is not None:
+            cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+        # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
+        # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
+        # Reference: https://github.com/pytorch/pytorch/issues/112577
+        if self.require_contiguous_qkv and query_states.device.type == "cuda" and attention_mask is not None:
+            query_states = query_states.contiguous()
+            key_states = key_states.contiguous()
+            value_states = value_states.contiguous()
+
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=attention_mask,
+            dropout_p=self.attention_dropout if self.training else 0.0,
+            is_causal=self.is_causal and attention_mask is None and q_len > 1,
+        )
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+        attn_output = self.dense(attn_output)
+
+        return attn_output, None, past_key_value
+
+
 PHI_ATTENTION_CLASSES = {
     "eager": PhiAttention,
     "flash_attention_2": PhiFlashAttention2,
+    "sdpa": PhiSdpaAttention,
 }
 
 
@@ -714,6 +831,7 @@ class PhiPreTrainedModel(PreTrainedModel):
     _no_split_modules = ["PhiDecoderLayer"]
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
+    _supports_sdpa = True
     _supports_cache_class = True
 
     def _init_weights(self, module):
@@ -821,7 +939,9 @@ class PhiModel(PhiPreTrainedModel):
             [PhiDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
         self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
         self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+        self._use_sdpa = config._attn_implementation == "sdpa"
         self.gradient_checkpointing = False
 
         # Initialize weights and apply final processing
@@ -895,6 +1015,13 @@ class PhiModel(PhiPreTrainedModel):
         if self._use_flash_attention_2:
             # 2d mask is passed through the layers
             attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+        elif self._use_sdpa and not output_attentions:
+            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                attention_mask,
+                (batch_size, seq_length),
+                inputs_embeds,
+                past_key_values_length,
+            )
         else:
             # 4d mask is passed through the layers
             attention_mask = _prepare_4d_causal_attention_mask(
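As a quick smoke test of the new path on this branch, something along these lines should route Phi through `PhiSdpaAttention`. This is a minimal sketch: the `microsoft/phi-1_5` checkpoint, prompt, and generation settings are only illustrative, and on a supported torch version SDPA should also be selected by default even without the explicit `attn_implementation` argument.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint and prompt; any Phi checkpoint routes through PhiSdpaAttention.
model_id = "microsoft/phi-1_5"
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    attn_implementation="sdpa",  # explicit opt-in; `attn_implementation="eager"` restores the manual path
).to(device)

inputs = tokenizer("def print_prime(n):", return_tensors="pt").to(device)

with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```

Note that passing `output_attentions=True` falls back to the manual attention implementation, as the `logger.warning_once` in `PhiSdpaAttention.forward` documents, so attention weights remain available when requested.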