mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[FA-2
] Add Flash Attention to Phi
(#27661)
* add FA and modify doc file * test_flash_attn_2_generate_padding_right test overwritten * comment * modify persimmon modeling file * added speedup graph * more changes
This commit is contained in:
parent
06f561687c
commit
f84d85ba67
@ -76,7 +76,7 @@ The original code for Phi-1 and Phi-1.5 can be found [here](https://huggingface.
|
||||
```python
|
||||
>>> from transformers import PhiForCausalLM, AutoTokenizer
|
||||
|
||||
>>> # define the model and tokenzier.
|
||||
>>> # define the model and tokenizer.
|
||||
>>> model = PhiForCausalLM.from_pretrained("susnato/phi-1_5_dev")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("susnato/phi-1_5_dev")
|
||||
|
||||
@ -94,6 +94,46 @@ The original code for Phi-1 and Phi-1.5 can be found [here](https://huggingface.
|
||||
```
|
||||
|
||||
|
||||
## Combining Phi and Flash Attention 2
|
||||
|
||||
First, make sure to install the latest version of Flash Attention 2 to include the sliding window attention feature.
|
||||
|
||||
```bash
|
||||
pip install -U flash-attn --no-build-isolation
|
||||
```
|
||||
|
||||
Make also sure that you have a hardware that is compatible with Flash-Attention 2. Read more about it in the official documentation of flash-attn repository. Make also sure to load your model in half-precision (e.g. `torch.float16``)
|
||||
|
||||
To load and run a model using Flash Attention 2, refer to the snippet below:
|
||||
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from transformers import PhiForCausalLM, AutoTokenizer
|
||||
|
||||
>>> # define the model and tokenizer and push the model and tokens to the GPU.
|
||||
>>> model = PhiForCausalLM.from_pretrained("susnato/phi-1_5_dev", torch_dtype=torch.float16, use_flash_attention_2=True).to("cuda")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("susnato/phi-1_5_dev")
|
||||
|
||||
>>> # feel free to change the prompt to your liking.
|
||||
>>> prompt = "If I were an AI that had just achieved"
|
||||
|
||||
>>> # apply the tokenizer.
|
||||
>>> tokens = tokenizer(prompt, return_tensors="pt").to("cuda")
|
||||
|
||||
>>> # use the model to generate new tokens.
|
||||
>>> generated_output = model.generate(**tokens, use_cache=True, max_new_tokens=10)
|
||||
|
||||
>>> tokenizer.batch_decode(generated_output)[0]
|
||||
'If I were an AI that had just achieved a breakthrough in machine learning, I would be thrilled'
|
||||
```
|
||||
|
||||
### Expected speedups
|
||||
Below is an expected speedup diagram that compares pure inference time between the native implementation in transformers using `susnato/phi-1_dev` checkpoint and the Flash Attention 2 version of the model using a sequence length of 2048.
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/ybelkada/documentation-images/resolve/main/phi_1_speedup_plot.jpg">
|
||||
</div>
|
||||
|
||||
|
||||
## PhiConfig
|
||||
|
||||
[[autodoc]] PhiConfig
|
||||
|
@ -187,6 +187,7 @@ class PersimmonAttention(nn.Module):
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.rope_theta = config.rope_theta
|
||||
self.partial_rotary_factor = config.partial_rotary_factor
|
||||
self.is_causal = True
|
||||
|
||||
if (self.head_dim * self.num_heads) != self.hidden_size:
|
||||
raise ValueError(
|
||||
|
@ -20,6 +20,7 @@ import math
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
@ -37,12 +38,19 @@ from ...utils import (
|
||||
add_code_sample_docstrings,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_flash_attn_2_available,
|
||||
is_flash_attn_greater_or_equal_2_10,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from .configuration_phi import PhiConfig
|
||||
|
||||
|
||||
if is_flash_attn_2_available():
|
||||
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
||||
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CHECKPOINT_FOR_DOC = "susnato/phi-1_dev"
|
||||
@ -55,6 +63,19 @@ PHI_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
]
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
|
||||
def _get_unpad_data(attention_mask):
|
||||
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
||||
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
||||
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
||||
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
|
||||
return (
|
||||
indices,
|
||||
cu_seqlens,
|
||||
max_seqlen_in_batch,
|
||||
)
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Phi
|
||||
class PhiRotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
||||
@ -205,6 +226,7 @@ class PhiAttention(nn.Module):
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.rope_theta = config.rope_theta
|
||||
self.partial_rotary_factor = config.partial_rotary_factor
|
||||
self.is_causal = True
|
||||
|
||||
if (self.head_dim * self.num_heads) != self.hidden_size:
|
||||
raise ValueError(
|
||||
@ -361,10 +383,233 @@ class PhiAttention(nn.Module):
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
class PhiFlashAttention2(PhiAttention):
|
||||
"""
|
||||
Phi flash attention module. This module inherits from `PhiAttention` as the weights of the module stays untouched.
|
||||
The only required change would be on the forward pass where it needs to correctly call the public API of flash
|
||||
attention and deal with padding tokens in case the input contains any of them.
|
||||
"""
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
||||
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
||||
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
||||
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
# PhiFlashAttention2 attention does not support output_attentions
|
||||
|
||||
output_attentions = False
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
# [batch_size, seq_length, 3 x hidden_size]
|
||||
fused_qkv = self.query_key_value(hidden_states)
|
||||
|
||||
# 3 x [batch_size, seq_length, num_heads, head_dim]
|
||||
(query_states, key_states, value_states) = self._split_heads(fused_qkv)
|
||||
|
||||
if self.qk_layernorm:
|
||||
query_states = self.q_layernorm(query_states)
|
||||
key_states = self.k_layernorm(key_states)
|
||||
|
||||
# [batch_size, num_heads, seq_length, head_dim] -> [batch_size, seq_length, num_heads, head_dim]
|
||||
query_states = query_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
|
||||
# Partial rotary embedding
|
||||
query_rot, query_pass = (
|
||||
query_states[..., : self.rotary_emb.dim],
|
||||
query_states[..., self.rotary_emb.dim :],
|
||||
)
|
||||
key_rot, key_pass = (
|
||||
key_states[..., : self.rotary_emb.dim],
|
||||
key_states[..., self.rotary_emb.dim :],
|
||||
)
|
||||
# [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
|
||||
query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
|
||||
|
||||
# [batch_size, seq_length, num_heads, head_dim]
|
||||
query_states = torch.cat((query_rot, query_pass), dim=-1)
|
||||
key_states = torch.cat((key_rot, key_pass), dim=-1)
|
||||
|
||||
if past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
tgt_len = key_states.shape[2]
|
||||
|
||||
# Flash attention requires the input to have the shape
|
||||
# batch_size x seq_length x head_dim x hidden_dim
|
||||
query_states = query_states.transpose(1, 2).view(bsz, q_len, self.num_heads, self.head_dim)
|
||||
key_states = key_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)
|
||||
value_states = value_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)
|
||||
|
||||
attn_dropout = self.config.attention_dropout if self.training else 0.0
|
||||
|
||||
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
||||
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
||||
# cast them back in the correct dtype just to be sure everything works as expected.
|
||||
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
|
||||
# in fp32.
|
||||
|
||||
if query_states.dtype == torch.float32:
|
||||
# Handle the case where the model is quantized
|
||||
if hasattr(self.config, "_pre_quantization_dtype"):
|
||||
target_dtype = self.config._pre_quantization_dtype
|
||||
else:
|
||||
target_dtype = self.q_proj.weight.dtype
|
||||
|
||||
logger.warning_once(
|
||||
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
||||
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
||||
f" {target_dtype}."
|
||||
)
|
||||
|
||||
query_states = query_states.to(target_dtype)
|
||||
key_states = key_states.to(target_dtype)
|
||||
value_states = value_states.to(target_dtype)
|
||||
|
||||
attn_output = self._flash_attention_forward(
|
||||
query_states, key_states, value_states, attention_mask, q_len, dropout=attn_dropout, softmax_scale=1.0
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
|
||||
attn_output = self.dense(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
|
||||
def _flash_attention_forward(
|
||||
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
|
||||
):
|
||||
"""
|
||||
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
|
||||
first unpad the input, then computes the attention scores and pad the final attention scores.
|
||||
|
||||
Args:
|
||||
query_states (`torch.Tensor`):
|
||||
Input query states to be passed to Flash Attention API
|
||||
key_states (`torch.Tensor`):
|
||||
Input key states to be passed to Flash Attention API
|
||||
value_states (`torch.Tensor`):
|
||||
Input value states to be passed to Flash Attention API
|
||||
attention_mask (`torch.Tensor`):
|
||||
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
|
||||
position of padding tokens and 1 for the position of non-padding tokens.
|
||||
dropout (`int`, *optional*):
|
||||
Attention dropout
|
||||
softmax_scale (`float`, *optional*):
|
||||
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
|
||||
"""
|
||||
if not self._flash_attn_uses_top_left_mask:
|
||||
causal = self.is_causal
|
||||
else:
|
||||
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
|
||||
causal = self.is_causal and query_length != 1
|
||||
|
||||
# Contains at least one padding token in the sequence
|
||||
if attention_mask is not None:
|
||||
batch_size = query_states.shape[0]
|
||||
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
|
||||
query_states, key_states, value_states, attention_mask, query_length
|
||||
)
|
||||
|
||||
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
||||
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
|
||||
|
||||
attn_output_unpad = flash_attn_varlen_func(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k=cu_seqlens_k,
|
||||
max_seqlen_q=max_seqlen_in_batch_q,
|
||||
max_seqlen_k=max_seqlen_in_batch_k,
|
||||
dropout_p=dropout,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=causal,
|
||||
)
|
||||
|
||||
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
|
||||
else:
|
||||
attn_output = flash_attn_func(
|
||||
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
|
||||
)
|
||||
|
||||
return attn_output
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
|
||||
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
|
||||
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
|
||||
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
|
||||
|
||||
key_layer = index_first_axis(
|
||||
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
|
||||
)
|
||||
value_layer = index_first_axis(
|
||||
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
|
||||
)
|
||||
if query_length == kv_seq_len:
|
||||
query_layer = index_first_axis(
|
||||
query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
|
||||
)
|
||||
cu_seqlens_q = cu_seqlens_k
|
||||
max_seqlen_in_batch_q = max_seqlen_in_batch_k
|
||||
indices_q = indices_k
|
||||
elif query_length == 1:
|
||||
max_seqlen_in_batch_q = 1
|
||||
cu_seqlens_q = torch.arange(
|
||||
batch_size + 1, dtype=torch.int32, device=query_layer.device
|
||||
) # There is a memcpy here, that is very bad.
|
||||
indices_q = cu_seqlens_q[:-1]
|
||||
query_layer = query_layer.squeeze(1)
|
||||
else:
|
||||
# The -q_len: slice assumes left padding.
|
||||
attention_mask = attention_mask[:, -query_length:]
|
||||
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
|
||||
|
||||
return (
|
||||
query_layer,
|
||||
key_layer,
|
||||
value_layer,
|
||||
indices_q,
|
||||
(cu_seqlens_q, cu_seqlens_k),
|
||||
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
|
||||
)
|
||||
|
||||
|
||||
class PhiDecoderLayer(nn.Module):
|
||||
def __init__(self, config: PhiConfig):
|
||||
super().__init__()
|
||||
self.self_attn = PhiAttention(config=config)
|
||||
self.self_attn = (
|
||||
PhiAttention(config=config)
|
||||
if not getattr(config, "_flash_attn_2_enabled", False)
|
||||
else PhiFlashAttention2(config=config)
|
||||
)
|
||||
self.mlp = PhiMLP(config)
|
||||
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.resid_dropout = nn.Dropout(config.resid_pdrop)
|
||||
@ -450,6 +695,7 @@ class PhiPreTrainedModel(PreTrainedModel):
|
||||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
_supports_flash_attn_2 = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.initializer_range
|
||||
@ -609,14 +855,15 @@ class PhiModel(PhiPreTrainedModel):
|
||||
|
||||
inputs_embeds = self.embed_dropout(inputs_embeds)
|
||||
|
||||
# embed positions
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(
|
||||
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
|
||||
# Attention mask.
|
||||
if getattr(self.config, "_flash_attn_2_enabled", False):
|
||||
# 2d mask is passed through the layers
|
||||
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
||||
else:
|
||||
# 4d mask is passed through the layers
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
|
||||
)
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
|
@ -18,8 +18,17 @@
|
||||
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from transformers import PhiConfig, is_torch_available
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
from transformers.testing_utils import (
|
||||
require_bitsandbytes,
|
||||
require_flash_attn,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
@ -31,6 +40,7 @@ if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
PhiForCausalLM,
|
||||
PhiForSequenceClassification,
|
||||
PhiForTokenClassification,
|
||||
@ -350,6 +360,43 @@ class PhiModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@require_bitsandbytes
|
||||
@pytest.mark.flash_attn_test
|
||||
@slow
|
||||
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_flash_attn_2_generate_padding_right with LlamaForCausalLM->PhiForCausalLM,LlamaTokenizer->AutoTokenizer,meta-llama/Llama-2-7b-hf->susnato/phi-1_5_dev
|
||||
def test_flash_attn_2_generate_padding_right(self):
|
||||
"""
|
||||
Overwritting the common test as the test is flaky on tiny models
|
||||
"""
|
||||
model = PhiForCausalLM.from_pretrained(
|
||||
"susnato/phi-1_5_dev",
|
||||
load_in_4bit=True,
|
||||
device_map={"": 0},
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("susnato/phi-1_5_dev")
|
||||
|
||||
texts = ["hi", "Hello this is a very long sentence"]
|
||||
|
||||
tokenizer.padding_side = "right"
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0)
|
||||
|
||||
output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||
output_native = tokenizer.batch_decode(output_native)
|
||||
|
||||
model = PhiForCausalLM.from_pretrained(
|
||||
"susnato/phi-1_5_dev", load_in_4bit=True, device_map={"": 0}, use_flash_attention_2=True
|
||||
)
|
||||
|
||||
output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||
output_fa_2 = tokenizer.batch_decode(output_fa_2)
|
||||
|
||||
self.assertListEqual(output_native, output_fa_2)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
|
Loading…
Reference in New Issue
Block a user