Add flash attention for gpt_bigcode (#26479)

* added flash attention of gpt_bigcode

* changed docs

* Update src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py

* add FA-2 docs

* oops

* Update docs/source/en/perf_infer_gpu_one.md Last Nit

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fix

* oops

* remove padding_mask

* change getattr->hasattr logic

* changed .md file

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: younesbelkada <younesbelkada@gmail.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Susnato Dhar 2023-10-31 16:51:02 +05:30 committed by GitHub
parent 9dc4ce9ea7
commit b5db8ca66f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 328 additions and 23 deletions

View File

@ -42,6 +42,45 @@ The main differences compared to GPT2.
You can read more about the optimizations in the [original pull request](https://github.com/huggingface/transformers/pull/22575)
## Combining Starcoder and Flash Attention 2
First, make sure to install the latest version of Flash Attention 2.
```bash
pip install -U flash-attn --no-build-isolation
```
Also make sure that your hardware is compatible with Flash Attention 2; read more about it in the official documentation of the flash-attn repository. Finally, make sure to load your model in half-precision (e.g. `torch.float16`).
To load and run a model using Flash Attention 2, refer to the snippet below:
```python
>>> import torch
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> device = "cuda" # the device to load the model onto
>>> model = AutoModelForCausalLM.from_pretrained("bigcode/gpt_bigcode-santacoder", torch_dtype=torch.float16, use_flash_attention_2=True)
>>> tokenizer = AutoTokenizer.from_pretrained("bigcode/gpt_bigcode-santacoder")
>>> prompt = "def hello_world():"
>>> model_inputs = tokenizer([prompt], return_tensors="pt").to(device)
>>> model.to(device)
>>> generated_ids = model.generate(**model_inputs, max_new_tokens=30, do_sample=False)
>>> tokenizer.batch_decode(generated_ids)[0]
'def hello_world():\n print("hello world")\n\nif __name__ == "__main__":\n print("hello world")\n<|endoftext|>'
```
### Expected speedups
Below is an expected speedup diagram that compares pure inference time between the native implementation in transformers using the `bigcode/starcoder` checkpoint and the Flash Attention 2 version of the model using two different sequence lengths.
<div style="text-align: center">
<img src="https://huggingface.co/datasets/ybelkada/documentation-images/resolve/main/starcoder-speedup.png">
</div>
## GPTBigCodeConfig
[[autodoc]] GPTBigCodeConfig

View File

@ -34,6 +34,7 @@ We natively support Flash Attention 2 for the following models:
- Llama
- Mistral
- Falcon
- [GPTBigCode (Starcoder)](model_doc/gpt_bigcode#)
You can request to add Flash Attention 2 support for more models by opening an issue on GitHub, and even open a Pull Request to integrate the changes. The supported models can be used for inference and training, including training with padding tokens - *which is currently not supported for `BetterTransformer` API below.*

View File

@ -16,6 +16,7 @@ import math
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@ -32,11 +33,17 @@ from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
logging,
)
from .configuration_gpt_bigcode import GPTBigCodeConfig
if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "bigcode/gpt_bigcode-santacoder"
@ -78,11 +85,25 @@ def masked_softmax(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor
return x
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(attention_mask):
    """
    Compute the indexing metadata used to strip masked-out positions from a batch.

    Args:
        attention_mask (`torch.Tensor`):
            Mask whose last dimension is summed to get the number of kept positions per row; non-zero entries mark
            positions to keep.

    Returns:
        `tuple`: `(indices, cu_seqlens, max_seqlen_in_batch)` where `indices` holds the positions of the non-zero
        entries in the flattened mask, `cu_seqlens` is the int32 cumulative sum of the per-row lengths with a leading
        0, and `max_seqlen_in_batch` is the largest per-row length as a Python number.
    """
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    # Prepend a 0 so that consecutive entries delimit each sequence: [0, len_0, len_0 + len_1, ...].
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return (
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )
class GPTBigCodeAttention(nn.Module):
def __init__(self, config, is_cross_attention=False, layer_idx=None):
super().__init__()
self.mask_value = None
self.config = config
self.mask_value = None
self.multi_query = config.multi_query
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
@ -90,6 +111,8 @@ class GPTBigCodeAttention(nn.Module):
self.kv_heads = 1 if self.multi_query else self.num_heads
self.kv_dim = self.kv_heads * self.head_dim
self.split_size = self.embed_dim
self.is_causal = True
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
@ -212,10 +235,16 @@ class GPTBigCodeAttention(nn.Module):
encoder_attention_mask: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
**kwargs,
) -> Union[
Tuple[torch.Tensor, Optional[torch.Tensor]],
Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]],
]:
if "padding_mask" in kwargs:
logger.warning_once(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
if encoder_hidden_states is not None:
if not hasattr(self, "q_attn") or not self.is_cross_attention:
raise ValueError(
@ -262,6 +291,223 @@ class GPTBigCodeAttention(nn.Module):
return outputs # a, present, (attentions)
class GPTBigCodeFlashAttention2(GPTBigCodeAttention):
    """
    GPTBigCode flash attention module. This module inherits from `GPTBigCodeAttention` as the weights of the module
    stay untouched. The only required change is in the forward pass, which needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        layer_past: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs,
    ) -> Union[
        Tuple[torch.Tensor, Optional[torch.Tensor]],
        Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]],
    ]:
        """
        Run attention through the Flash Attention 2 kernels.

        Args:
            hidden_states (`torch.Tensor`):
                Input projected by `c_attn` (or by `q_attn` for the query when used as cross-attention).
            layer_past (`torch.Tensor`, *optional*):
                Cached key/value tensor; the new key/value is concatenated after it along dim `-2`.
            attention_mask (`torch.Tensor`, *optional*):
                Padding mask forwarded to `_flash_attention_forward`. When `None`, the non-padded kernel is used.
            head_mask (`torch.Tensor`, *optional*):
                Accepted for interface compatibility; it is not used by this implementation.
            encoder_hidden_states (`torch.Tensor`, *optional*):
                When given, keys/values are computed from it and `encoder_attention_mask` replaces `attention_mask`.
            encoder_attention_mask (`torch.Tensor`, *optional*):
                Mask used in place of `attention_mask` in the cross-attention case.
            use_cache (`bool`, *optional*):
                Whether to return the (concatenated) key/value tensor as `present`.
            output_attentions (`bool`, *optional*):
                When truthy, a third element is appended to the output. It is always `None`: the flash attention
                kernels called here return only the attention output, never the attention probabilities.

        Returns:
            `tuple`: `(attn_output, present)` or `(attn_output, present, None)` when `output_attentions` is set.

        Raises:
            ValueError: if `encoder_hidden_states` is given but the module has no `q_attn` or is not configured as
                cross-attention.
        """
        if "padding_mask" in kwargs:
            logger.warning_once(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )

            # overwrite attention_mask with padding_mask
            attention_mask = kwargs.pop("padding_mask")

        if encoder_hidden_states is not None:
            if not hasattr(self, "q_attn") or not self.is_cross_attention:
                raise ValueError(
                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
                    "Please make sure to instantiate class with `GPTBigCodeAttention(..., is_cross_attention=True)`."
                )

            # NOTE(review): the kernels below are always called with `causal=self.is_causal`; confirm that a causal
            # mask is really intended when this module is used for cross-attention.
            query = self.q_attn(hidden_states)
            key_value = self.c_attn(encoder_hidden_states)
            attention_mask = encoder_attention_mask
        elif self.multi_query:
            query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2)
        else:
            # Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim),
            # i.e., the memory layout is not the same as GPT2.
            # This makes the concatenation with past_key_value more efficient.
            query, key_value = (
                self.c_attn(hidden_states)
                .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim)
                .transpose(1, 2)
                .split((self.head_dim, 2 * self.head_dim), dim=3)
            )

        if layer_past is not None:
            key_value = torch.cat((layer_past, key_value), dim=-2)
        present = key_value if use_cache else None

        key, value = key_value.split((self.head_dim, self.head_dim), dim=-1)

        # Flash attention requires the input to have the shape
        # batch_size x seq_length x num_heads x head_dim
        if self.multi_query:
            batch_size, query_length, _ = query.shape
            query = query.reshape(batch_size, query_length, self.num_heads, self.head_dim)
            # Multi-query attention: a single key/value head, inserted as a size-1 heads dimension.
            key = key.unsqueeze(2)
            value = value.unsqueeze(2)
        else:
            query_length = query.shape[2]
            batch_size, _, tgt, _ = key.shape
            query = query.transpose(1, 2).reshape(batch_size, query_length, self.num_heads, self.head_dim)
            key = key.transpose(1, 2).reshape(batch_size, tgt, self.num_heads, self.head_dim)
            value = value.transpose(1, 2).reshape(batch_size, tgt, self.num_heads, self.head_dim)

        # NOTE(review): `self.dropout` is expected to be a float probability (it is passed as `dropout_p`); its
        # definition is not visible from this class - confirm the parent `__init__` sets it.
        attn_dropout = self.dropout if self.training else 0.0

        softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else query.dtype
        upcast = query.dtype != softmax_dtype
        # NOTE(review): the flash attention kernels have no separate "unscale" step, so the 1 / (layer_idx + 1)
        # factor stays in the softmax input here - confirm this matches the non-flash attention path.
        softmax_scale = self.layer_idx + 1 if self.scale_attention_softmax_in_fp32 and upcast else 1
        softmax_scale = softmax_scale**-1
        if self.scale_attn_weights:
            softmax_scale /= self.head_dim**0.5

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
        # therefore the input hidden states gets silently casted in float32. Hence, we need
        # cast them back in float16 just to be sure everything works as expected.
        input_dtype = query.dtype
        if input_dtype == torch.float32:
            # Handle the case where the model is quantized
            if hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.c_attn.weight.dtype

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )
            query = query.to(target_dtype)
            key = key.to(target_dtype)
            value = value.to(target_dtype)

        attn_output = self._flash_attention_forward(
            query, key, value, attention_mask, query_length, dropout=attn_dropout, softmax_scale=softmax_scale
        )

        # Merge the heads back: (batch_size, query_length, num_heads * head_dim)
        attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        outputs = (attn_output, present)
        if output_attentions:
            # Only the attention output is available here, not the attention probabilities, so there is nothing
            # meaningful to return as weights. Keep the tuple arity and return `None` in that slot.
            outputs += (None,)

        return outputs  # a, present, (attentions)

    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
    def _flash_attention_forward(
        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
    ):
        """
        Calls the forward method of Flash Attention - if `attention_mask` is provided, first unpad the input, then
        compute the attention output and pad the result back to the original layout.

        Args:
            query_states (`torch.Tensor`):
                Input query states to be passed to Flash Attention API
            key_states (`torch.Tensor`):
                Input key states to be passed to Flash Attention API
            value_states (`torch.Tensor`):
                Input value states to be passed to Flash Attention API
            attention_mask (`torch.Tensor`):
                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
                position of padding tokens and 1 for the position of non-padding tokens.
            query_length (`int`):
                Length of the query sequence; used to unpad the queries and to pad the output back.
            dropout (`float`, *optional*):
                Attention dropout
            softmax_scale (`float`, *optional*):
                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
        """
        # A mask was provided: take the variable-length (unpadded) path
        if attention_mask is not None:
            batch_size = query_states.shape[0]
            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
                query_states, key_states, value_states, attention_mask, query_length
            )

            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

            attn_output_unpad = flash_attn_varlen_func(
                query_states,
                key_states,
                value_states,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q,
                max_seqlen_k=max_seqlen_in_batch_k,
                dropout_p=dropout,
                softmax_scale=softmax_scale,
                causal=self.is_causal,
            )

            # Scatter the unpadded output back to (batch_size, query_length, ...)
            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
        else:
            attn_output = flash_attn_func(
                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal
            )

        return attn_output

    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
        """
        Remove the masked-out positions from query/key/value and build the metadata needed by
        `flash_attn_varlen_func`.

        Keys and values are flattened over (batch, sequence) and gathered with the indices derived from
        `attention_mask`. Queries are handled in one of three ways depending on `query_length`: same length as the
        keys (reuse the key indices), a single token per sequence, or any other length (the last `query_length`
        columns of the mask are used).

        Returns:
            `tuple`: `(query_layer, key_layer, value_layer, indices_q, (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k))`.
        """
        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )
        if query_length == kv_seq_len:
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            # One query token per sequence: sequence boundaries are simply 0, 1, ..., batch_size.
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )  # There is a memcpy here, that is very bad.
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # The -q_len: slice assumes left padding.
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )
class GPTBigCodeMLP(nn.Module):
def __init__(self, intermediate_size, config):
super().__init__()
@ -287,13 +533,21 @@ class GPTBigCodeBlock(nn.Module):
self.inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = GPTBigCodeAttention(config, layer_idx=layer_idx)
self.attn = (
GPTBigCodeAttention(config, layer_idx=layer_idx)
if not getattr(config, "_flash_attn_2_enabled", False)
else GPTBigCodeFlashAttention2(config, layer_idx=layer_idx)
)
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
if config.add_cross_attention:
if config.multi_query:
raise NotImplementedError("Cross-attention not implemented for MQA")
self.crossattention = GPTBigCodeAttention(config, is_cross_attention=True, layer_idx=layer_idx)
self.crossattention = (
GPTBigCodeAttention(config, is_cross_attention=True, layer_idx=layer_idx)
if not getattr(config, "_flash_attn_2_enabled", False)
else GPTBigCodeFlashAttention2(config, is_cross_attention=True, layer_idx=layer_idx)
)
self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = GPTBigCodeMLP(self.inner_dim, config)
@ -373,6 +627,7 @@ class GPTBigCodePreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["GPTBigCodeBlock"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
@ -594,28 +849,38 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
key_length = past_length + query_length
self_attention_mask = self.bias[None, key_length - query_length : key_length, :key_length]
if attention_mask is not None:
self_attention_mask = self_attention_mask * attention_mask.view(batch_size, 1, -1).to(
dtype=torch.bool, device=self_attention_mask.device
if getattr(self.config, "_flash_attn_2_enabled", False):
# 2d mask is passed through the layers
attention_mask = attention_mask.bool() if (attention_mask is not None and 0 in attention_mask) else None
encoder_attention_mask = (
encoder_attention_mask.bool()
if (encoder_attention_mask is not None and 0 in encoder_attention_mask)
else None
)
# MQA models: (batch_size, query_length, n_heads, key_length)
# MHA models: (batch_size, n_heads, query_length, key_length)
attention_mask = self_attention_mask.unsqueeze(2 if self.multi_query else 1)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if (
self.config.add_cross_attention
and encoder_hidden_states is not None
and encoder_attention_mask is not None
):
if encoder_attention_mask.dim() == 2:
encoder_attention_mask.unsqueeze(1)
assert encoder_attention_mask.dim() == 3
encoder_attention_mask = encoder_attention_mask.bool().unsqueeze(2 if self.multi_query else 1)
else:
encoder_attention_mask = None
# 4d mask is passed through the layers
if attention_mask is not None:
self_attention_mask = self_attention_mask * attention_mask.view(batch_size, 1, -1).to(
dtype=torch.bool, device=self_attention_mask.device
)
# MQA models: (batch_size, query_length, n_heads, key_length)
# MHA models: (batch_size, n_heads, query_length, key_length)
attention_mask = self_attention_mask.unsqueeze(2 if self.multi_query else 1)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if (
self.config.add_cross_attention
and encoder_hidden_states is not None
and encoder_attention_mask is not None
):
if encoder_attention_mask.dim() == 2:
encoder_attention_mask.unsqueeze(1)
assert encoder_attention_mask.dim() == 3
encoder_attention_mask = encoder_attention_mask.bool().unsqueeze(2 if self.multi_query else 1)
else:
encoder_attention_mask = None
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head