mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 18:22:34 +06:00
Adding Flash Attention 2 Support for GPT2 (#29226)
* First commit to add flash attention 2 for GPT-2 * more improvements * Make GPT2 pass tests and fixed Decison Transformers copies * Fixed missing arg * fix copies * Added expected speedup * Update src/transformers/models/gpt2/modeling_gpt2.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/models/gpt2/modeling_gpt2.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/models/gpt2/modeling_gpt2.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Added test * Fixed attn attribute * Update docs/source/en/model_doc/gpt2.md Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update docs/source/en/model_doc/gpt2.md Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update Decision transformer attentions * More updates * Passing tests * Fix copies * Fix copies part 2 * Decision transformer updates * Update src/transformers/models/gpt2/modeling_gpt2.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Fix copies * Decision transformer not supporting flash attn * Addressed comments * Addressed comments * Addressed comments --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
parent
3a7e68362b
commit
22d159ddf9
@ -60,6 +60,73 @@ This model was contributed by [thomwolf](https://huggingface.co/thomwolf). The o
|
|||||||
- Enabling the *scale_attn_by_inverse_layer_idx* and *reorder_and_upcast_attn* flags will apply the training stability
|
- Enabling the *scale_attn_by_inverse_layer_idx* and *reorder_and_upcast_attn* flags will apply the training stability
|
||||||
improvements from [Mistral](https://github.com/stanford-crfm/mistral/) (for PyTorch only).
|
improvements from [Mistral](https://github.com/stanford-crfm/mistral/) (for PyTorch only).
|
||||||
|
|
||||||
|
## Usage example
|
||||||
|
|
||||||
|
The `generate()` method can be used to generate text using GPT2 model.
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
|
>>> model = AutoModelForCausalLM.from_pretrained("gpt2")
|
||||||
|
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
||||||
|
|
||||||
|
>>> prompt = "GPT2 is a model developed by OpenAI."
|
||||||
|
|
||||||
|
>>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids
|
||||||
|
|
||||||
|
>>> gen_tokens = model.generate(
|
||||||
|
... input_ids,
|
||||||
|
... do_sample=True,
|
||||||
|
... temperature=0.9,
|
||||||
|
... max_length=100,
|
||||||
|
... )
|
||||||
|
>>> gen_text = tokenizer.batch_decode(gen_tokens)[0]
|
||||||
|
```
|
||||||
|
|
||||||
|
## Using Flash Attention 2
|
||||||
|
|
||||||
|
Flash Attention 2 is a faster, optimized version of the attention scores computation which relies on `cuda` kernels.
|
||||||
|
|
||||||
|
### Installation
|
||||||
|
|
||||||
|
First, check whether your hardware is compatible with Flash Attention 2. The latest list of compatible hardware can be found in the [official documentation](https://github.com/Dao-AILab/flash-attention#installation-and-features). If your hardware is not compatible with Flash Attention 2, you can still benefit from attention kernel optimisations through Better Transformer support covered [above](https://huggingface.co/docs/transformers/main/en/model_doc/bark#using-better-transformer).
|
||||||
|
|
||||||
|
Next, [install](https://github.com/Dao-AILab/flash-attention#installation-and-features) the latest version of Flash Attention 2:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install -U flash-attn --no-build-isolation
|
||||||
|
```
|
||||||
|
|
||||||
|
### Usage
|
||||||
|
|
||||||
|
To load a model using Flash Attention 2, we can pass the argument `attn_implementation="flash_attention_2"` to [`.from_pretrained`](https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained). We'll also load the model in half-precision (e.g. `torch.float16`), since it results in almost no degradation to audio quality but significantly lower memory usage and faster inference:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> import torch
|
||||||
|
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
>>> device = "cuda" # the device to load the model onto
|
||||||
|
|
||||||
|
>>> model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.float16, attn_implementation="flash_attention_2")
|
||||||
|
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
||||||
|
|
||||||
|
>>> prompt = "def hello_world():"
|
||||||
|
|
||||||
|
>>> model_inputs = tokenizer([prompt], return_tensors="pt").to(device)
|
||||||
|
>>> model.to(device)
|
||||||
|
|
||||||
|
>>> generated_ids = model.generate(**model_inputs, max_new_tokens=100, do_sample=True)
|
||||||
|
>>> tokenizer.batch_decode(generated_ids)[0]
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
### Expected speedups
|
||||||
|
|
||||||
|
Below is an expected speedup diagram that compares pure inference time between the native implementation in transformers using `gpt2` checkpoint and the Flash Attention 2 version of the model using a sequence length of 512.
|
||||||
|
|
||||||
|
<div style="text-align: center">
|
||||||
|
<img src="https://huggingface.co/datasets/EduardoPacheco/documentation-images/resolve/main/gpt2_flash_attention_2_speedup.jpg">
|
||||||
|
</div>
|
||||||
|
|
||||||
## Resources
|
## Resources
|
||||||
|
|
||||||
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with GPT2. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
|
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with GPT2. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
|
||||||
|
@ -42,6 +42,7 @@ FlashAttention-2 is currently supported for the following architectures:
|
|||||||
* [Cohere](https://huggingface.co/docs/transformers/model_doc/cohere#transformers.CohereModel)
|
* [Cohere](https://huggingface.co/docs/transformers/model_doc/cohere#transformers.CohereModel)
|
||||||
* [DistilBert](https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel)
|
* [DistilBert](https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel)
|
||||||
* [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel)
|
* [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel)
|
||||||
|
* [GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2)
|
||||||
* [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel)
|
* [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel)
|
||||||
* [GPTNeo](https://huggingface.co/docs/transformers/model_doc/gpt_neo#transformers.GPTNeoModel)
|
* [GPTNeo](https://huggingface.co/docs/transformers/model_doc/gpt_neo#transformers.GPTNeoModel)
|
||||||
* [GPTNeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox#transformers.GPTNeoXModel)
|
* [GPTNeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox#transformers.GPTNeoXModel)
|
||||||
|
@ -108,7 +108,7 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
|
|||||||
class DecisionTransformerGPT2Attention(nn.Module):
|
class DecisionTransformerGPT2Attention(nn.Module):
|
||||||
def __init__(self, config, is_cross_attention=False, layer_idx=None):
|
def __init__(self, config, is_cross_attention=False, layer_idx=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
max_positions = config.max_position_embeddings
|
max_positions = config.max_position_embeddings
|
||||||
self.register_buffer(
|
self.register_buffer(
|
||||||
"bias",
|
"bias",
|
||||||
@ -146,6 +146,7 @@ class DecisionTransformerGPT2Attention(nn.Module):
|
|||||||
|
|
||||||
self.attn_dropout = nn.Dropout(config.attn_pdrop)
|
self.attn_dropout = nn.Dropout(config.attn_pdrop)
|
||||||
self.resid_dropout = nn.Dropout(config.resid_pdrop)
|
self.resid_dropout = nn.Dropout(config.resid_pdrop)
|
||||||
|
self.is_causal = True
|
||||||
|
|
||||||
self.pruned_heads = set()
|
self.pruned_heads = set()
|
||||||
|
|
||||||
@ -346,6 +347,7 @@ class DecisionTransformerGPT2MLP(nn.Module):
|
|||||||
|
|
||||||
# Copied from transformers.models.gpt2.modeling_gpt2.GPT2Block with GPT2->DecisionTransformerGPT2
|
# Copied from transformers.models.gpt2.modeling_gpt2.GPT2Block with GPT2->DecisionTransformerGPT2
|
||||||
class DecisionTransformerGPT2Block(nn.Module):
|
class DecisionTransformerGPT2Block(nn.Module):
|
||||||
|
# Ignore copy
|
||||||
def __init__(self, config, layer_idx=None):
|
def __init__(self, config, layer_idx=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
hidden_size = config.hidden_size
|
hidden_size = config.hidden_size
|
||||||
@ -497,7 +499,6 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
|
|||||||
def set_input_embeddings(self, new_embeddings):
|
def set_input_embeddings(self, new_embeddings):
|
||||||
self.wte = new_embeddings
|
self.wte = new_embeddings
|
||||||
|
|
||||||
# Copied from transformers.models.gpt2.modeling_gpt2.GPT2Model.forward
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: Optional[torch.LongTensor] = None,
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
@ -548,7 +549,7 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
|
|||||||
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
|
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
|
||||||
position_ids = position_ids.unsqueeze(0)
|
position_ids = position_ids.unsqueeze(0)
|
||||||
|
|
||||||
# GPT2Attention mask.
|
# Attention mask.
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
if batch_size <= 0:
|
if batch_size <= 0:
|
||||||
raise ValueError("batch_size has to be defined and > 0")
|
raise ValueError("batch_size has to be defined and > 0")
|
||||||
|
@ -22,6 +22,7 @@ from dataclasses import dataclass
|
|||||||
from typing import Optional, Tuple, Union
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.cuda.amp import autocast
|
from torch.cuda.amp import autocast
|
||||||
@ -42,6 +43,8 @@ from ...utils import (
|
|||||||
add_code_sample_docstrings,
|
add_code_sample_docstrings,
|
||||||
add_start_docstrings,
|
add_start_docstrings,
|
||||||
add_start_docstrings_to_model_forward,
|
add_start_docstrings_to_model_forward,
|
||||||
|
is_flash_attn_2_available,
|
||||||
|
is_flash_attn_greater_or_equal_2_10,
|
||||||
logging,
|
logging,
|
||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
)
|
)
|
||||||
@ -49,6 +52,11 @@ from ...utils.model_parallel_utils import assert_device_map, get_device_map
|
|||||||
from .configuration_gpt2 import GPT2Config
|
from .configuration_gpt2 import GPT2Config
|
||||||
|
|
||||||
|
|
||||||
|
if is_flash_attn_2_available():
|
||||||
|
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
||||||
|
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
_CHECKPOINT_FOR_DOC = "openai-community/gpt2"
|
_CHECKPOINT_FOR_DOC = "openai-community/gpt2"
|
||||||
@ -58,6 +66,19 @@ _CONFIG_FOR_DOC = "GPT2Config"
|
|||||||
from ..deprecated._archive_maps import GPT2_PRETRAINED_MODEL_ARCHIVE_LIST # noqa: F401, E402
|
from ..deprecated._archive_maps import GPT2_PRETRAINED_MODEL_ARCHIVE_LIST # noqa: F401, E402
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
|
||||||
|
def _get_unpad_data(attention_mask):
|
||||||
|
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
||||||
|
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
||||||
|
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
||||||
|
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
|
||||||
|
return (
|
||||||
|
indices,
|
||||||
|
cu_seqlens,
|
||||||
|
max_seqlen_in_batch,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
|
def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
|
||||||
"""Load tf checkpoints in a pytorch model"""
|
"""Load tf checkpoints in a pytorch model"""
|
||||||
try:
|
try:
|
||||||
@ -117,7 +138,7 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
|
|||||||
class GPT2Attention(nn.Module):
|
class GPT2Attention(nn.Module):
|
||||||
def __init__(self, config, is_cross_attention=False, layer_idx=None):
|
def __init__(self, config, is_cross_attention=False, layer_idx=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
max_positions = config.max_position_embeddings
|
max_positions = config.max_position_embeddings
|
||||||
self.register_buffer(
|
self.register_buffer(
|
||||||
"bias",
|
"bias",
|
||||||
@ -155,6 +176,7 @@ class GPT2Attention(nn.Module):
|
|||||||
|
|
||||||
self.attn_dropout = nn.Dropout(config.attn_pdrop)
|
self.attn_dropout = nn.Dropout(config.attn_pdrop)
|
||||||
self.resid_dropout = nn.Dropout(config.resid_pdrop)
|
self.resid_dropout = nn.Dropout(config.resid_pdrop)
|
||||||
|
self.is_causal = True
|
||||||
|
|
||||||
self.pruned_heads = set()
|
self.pruned_heads = set()
|
||||||
|
|
||||||
@ -335,6 +357,210 @@ class GPT2Attention(nn.Module):
|
|||||||
return outputs # a, present, (attentions)
|
return outputs # a, present, (attentions)
|
||||||
|
|
||||||
|
|
||||||
|
class GPT2FlashAttention2(GPT2Attention):
|
||||||
|
"""
|
||||||
|
GPT2 flash attention module. This module inherits from `GPT2Attention` as the weights of the module stays
|
||||||
|
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
|
||||||
|
flash attention and deal with padding tokens in case the input contains any of them.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
||||||
|
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
||||||
|
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
||||||
|
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
||||||
|
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||||
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
use_cache: Optional[bool] = False,
|
||||||
|
output_attentions: Optional[bool] = False,
|
||||||
|
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
|
||||||
|
bsz, _, _ = hidden_states.size()
|
||||||
|
if encoder_hidden_states is not None:
|
||||||
|
if not hasattr(self, "q_attn"):
|
||||||
|
raise ValueError(
|
||||||
|
"If class is used as cross attention, the weights `q_attn` have to be defined. "
|
||||||
|
"Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
|
||||||
|
)
|
||||||
|
|
||||||
|
query = self.q_attn(hidden_states)
|
||||||
|
key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
|
||||||
|
attention_mask = encoder_attention_mask
|
||||||
|
else:
|
||||||
|
query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
|
||||||
|
|
||||||
|
query = self._split_heads(query, self.num_heads, self.head_dim)
|
||||||
|
key = self._split_heads(key, self.num_heads, self.head_dim)
|
||||||
|
value = self._split_heads(value, self.num_heads, self.head_dim)
|
||||||
|
|
||||||
|
if layer_past is not None:
|
||||||
|
past_key = layer_past[0]
|
||||||
|
past_value = layer_past[1]
|
||||||
|
key = torch.cat((past_key, key), dim=-2)
|
||||||
|
value = torch.cat((past_value, value), dim=-2)
|
||||||
|
|
||||||
|
present = None
|
||||||
|
if use_cache is True:
|
||||||
|
present = (key, value)
|
||||||
|
|
||||||
|
query_length = query.shape[2]
|
||||||
|
tgt_len = key.shape[2]
|
||||||
|
|
||||||
|
# Flash attention requires the input to have the shape
|
||||||
|
# batch_size x seq_length x head_dim x hidden_dim
|
||||||
|
query = query.transpose(1, 2).view(bsz, query_length, self.num_heads, self.head_dim)
|
||||||
|
key = key.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)
|
||||||
|
value = value.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)
|
||||||
|
|
||||||
|
attn_dropout = self.attn_dropout.p if self.training else 0.0
|
||||||
|
|
||||||
|
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
||||||
|
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
||||||
|
# cast them back in the correct dtype just to be sure everything works as expected.
|
||||||
|
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
|
||||||
|
# in fp32. (LlamaRMSNorm handles it correctly)
|
||||||
|
|
||||||
|
if query.dtype == torch.float32:
|
||||||
|
if torch.is_autocast_enabled():
|
||||||
|
target_dtype = torch.get_autocast_gpu_dtype()
|
||||||
|
# Handle the case where the model is quantized
|
||||||
|
elif hasattr(self.config, "_pre_quantization_dtype"):
|
||||||
|
target_dtype = self.config._pre_quantization_dtype
|
||||||
|
else:
|
||||||
|
target_dtype = self.c_proj.weight.dtype
|
||||||
|
|
||||||
|
logger.warning_once(
|
||||||
|
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
||||||
|
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
||||||
|
f" {target_dtype}."
|
||||||
|
)
|
||||||
|
|
||||||
|
query = query.to(target_dtype)
|
||||||
|
key = key.to(target_dtype)
|
||||||
|
value = value.to(target_dtype)
|
||||||
|
|
||||||
|
attn_output = self._flash_attention_forward(
|
||||||
|
query, key, value, attention_mask, query_length, dropout=attn_dropout
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim)
|
||||||
|
attn_output = self.c_proj(attn_weights_reshaped)
|
||||||
|
attn_output = self.resid_dropout(attn_output)
|
||||||
|
|
||||||
|
outputs = (attn_output, present)
|
||||||
|
if output_attentions:
|
||||||
|
outputs += (attn_weights_reshaped,)
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
|
||||||
|
def _flash_attention_forward(
|
||||||
|
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
|
||||||
|
first unpad the input, then computes the attention scores and pad the final attention scores.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query_states (`torch.Tensor`):
|
||||||
|
Input query states to be passed to Flash Attention API
|
||||||
|
key_states (`torch.Tensor`):
|
||||||
|
Input key states to be passed to Flash Attention API
|
||||||
|
value_states (`torch.Tensor`):
|
||||||
|
Input value states to be passed to Flash Attention API
|
||||||
|
attention_mask (`torch.Tensor`):
|
||||||
|
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
|
||||||
|
position of padding tokens and 1 for the position of non-padding tokens.
|
||||||
|
dropout (`float`):
|
||||||
|
Attention dropout
|
||||||
|
softmax_scale (`float`, *optional*):
|
||||||
|
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
|
||||||
|
"""
|
||||||
|
if not self._flash_attn_uses_top_left_mask:
|
||||||
|
causal = self.is_causal
|
||||||
|
else:
|
||||||
|
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
|
||||||
|
causal = self.is_causal and query_length != 1
|
||||||
|
|
||||||
|
# Contains at least one padding token in the sequence
|
||||||
|
if attention_mask is not None:
|
||||||
|
batch_size = query_states.shape[0]
|
||||||
|
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
|
||||||
|
query_states, key_states, value_states, attention_mask, query_length
|
||||||
|
)
|
||||||
|
|
||||||
|
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
||||||
|
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
|
||||||
|
|
||||||
|
attn_output_unpad = flash_attn_varlen_func(
|
||||||
|
query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
cu_seqlens_q=cu_seqlens_q,
|
||||||
|
cu_seqlens_k=cu_seqlens_k,
|
||||||
|
max_seqlen_q=max_seqlen_in_batch_q,
|
||||||
|
max_seqlen_k=max_seqlen_in_batch_k,
|
||||||
|
dropout_p=dropout,
|
||||||
|
softmax_scale=softmax_scale,
|
||||||
|
causal=causal,
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
|
||||||
|
else:
|
||||||
|
attn_output = flash_attn_func(
|
||||||
|
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
|
||||||
|
)
|
||||||
|
|
||||||
|
return attn_output
|
||||||
|
|
||||||
|
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
|
||||||
|
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
|
||||||
|
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
|
||||||
|
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
|
||||||
|
|
||||||
|
key_layer = index_first_axis(
|
||||||
|
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
|
||||||
|
)
|
||||||
|
value_layer = index_first_axis(
|
||||||
|
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
|
||||||
|
)
|
||||||
|
if query_length == kv_seq_len:
|
||||||
|
query_layer = index_first_axis(
|
||||||
|
query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
|
||||||
|
)
|
||||||
|
cu_seqlens_q = cu_seqlens_k
|
||||||
|
max_seqlen_in_batch_q = max_seqlen_in_batch_k
|
||||||
|
indices_q = indices_k
|
||||||
|
elif query_length == 1:
|
||||||
|
max_seqlen_in_batch_q = 1
|
||||||
|
cu_seqlens_q = torch.arange(
|
||||||
|
batch_size + 1, dtype=torch.int32, device=query_layer.device
|
||||||
|
) # There is a memcpy here, that is very bad.
|
||||||
|
indices_q = cu_seqlens_q[:-1]
|
||||||
|
query_layer = query_layer.squeeze(1)
|
||||||
|
else:
|
||||||
|
# The -q_len: slice assumes left padding.
|
||||||
|
attention_mask = attention_mask[:, -query_length:]
|
||||||
|
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
|
||||||
|
|
||||||
|
return (
|
||||||
|
query_layer,
|
||||||
|
key_layer,
|
||||||
|
value_layer,
|
||||||
|
indices_q,
|
||||||
|
(cu_seqlens_q, cu_seqlens_k),
|
||||||
|
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class GPT2MLP(nn.Module):
|
class GPT2MLP(nn.Module):
|
||||||
def __init__(self, intermediate_size, config):
|
def __init__(self, intermediate_size, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -352,18 +578,25 @@ class GPT2MLP(nn.Module):
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
GPT2_ATTENTION_CLASSES = {
|
||||||
|
"eager": GPT2Attention,
|
||||||
|
"flash_attention_2": GPT2FlashAttention2,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class GPT2Block(nn.Module):
|
class GPT2Block(nn.Module):
|
||||||
def __init__(self, config, layer_idx=None):
|
def __init__(self, config, layer_idx=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
hidden_size = config.hidden_size
|
hidden_size = config.hidden_size
|
||||||
inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
|
inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
|
||||||
|
attention_class = GPT2_ATTENTION_CLASSES[config._attn_implementation]
|
||||||
|
|
||||||
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||||
self.attn = GPT2Attention(config, layer_idx=layer_idx)
|
self.attn = attention_class(config=config, layer_idx=layer_idx)
|
||||||
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||||
|
|
||||||
if config.add_cross_attention:
|
if config.add_cross_attention:
|
||||||
self.crossattention = GPT2Attention(config, is_cross_attention=True, layer_idx=layer_idx)
|
self.crossattention = attention_class(config=config, is_cross_attention=True, layer_idx=layer_idx)
|
||||||
self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||||
|
|
||||||
self.mlp = GPT2MLP(inner_dim, config)
|
self.mlp = GPT2MLP(inner_dim, config)
|
||||||
@ -443,6 +676,7 @@ class GPT2PreTrainedModel(PreTrainedModel):
|
|||||||
supports_gradient_checkpointing = True
|
supports_gradient_checkpointing = True
|
||||||
_no_split_modules = ["GPT2Block"]
|
_no_split_modules = ["GPT2Block"]
|
||||||
_skip_keys_device_placement = "past_key_values"
|
_skip_keys_device_placement = "past_key_values"
|
||||||
|
_supports_flash_attn_2 = True
|
||||||
|
|
||||||
def __init__(self, *inputs, **kwargs):
|
def __init__(self, *inputs, **kwargs):
|
||||||
super().__init__(*inputs, **kwargs)
|
super().__init__(*inputs, **kwargs)
|
||||||
@ -673,6 +907,7 @@ class GPT2Model(GPT2PreTrainedModel):
|
|||||||
self.model_parallel = False
|
self.model_parallel = False
|
||||||
self.device_map = None
|
self.device_map = None
|
||||||
self.gradient_checkpointing = False
|
self.gradient_checkpointing = False
|
||||||
|
self._attn_implementation = config._attn_implementation
|
||||||
|
|
||||||
# Initialize weights and apply final processing
|
# Initialize weights and apply final processing
|
||||||
self.post_init()
|
self.post_init()
|
||||||
@ -790,25 +1025,26 @@ class GPT2Model(GPT2PreTrainedModel):
|
|||||||
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
|
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
|
||||||
position_ids = position_ids.unsqueeze(0)
|
position_ids = position_ids.unsqueeze(0)
|
||||||
|
|
||||||
# GPT2Attention mask.
|
# Attention mask.
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
if batch_size <= 0:
|
|
||||||
raise ValueError("batch_size has to be defined and > 0")
|
|
||||||
attention_mask = attention_mask.view(batch_size, -1)
|
attention_mask = attention_mask.view(batch_size, -1)
|
||||||
# We create a 3D attention mask from a 2D tensor mask.
|
if self._attn_implementation == "flash_attention_2":
|
||||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
attention_mask = attention_mask if 0 in attention_mask else None
|
||||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
else:
|
||||||
# this attention mask is more simple than the triangular masking of causal attention
|
# We create a 3D attention mask from a 2D tensor mask.
|
||||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
# Sizes are [batch_size, 1, 1, to_seq_length]
|
||||||
attention_mask = attention_mask[:, None, None, :]
|
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||||
|
# this attention mask is more simple than the triangular masking of causal attention
|
||||||
|
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
||||||
|
attention_mask = attention_mask[:, None, None, :]
|
||||||
|
|
||||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||||
# masked positions, this operation will create a tensor which is 0.0 for
|
# masked positions, this operation will create a tensor which is 0.0 for
|
||||||
# positions we want to attend and the dtype's smallest value for masked positions.
|
# positions we want to attend and the dtype's smallest value for masked positions.
|
||||||
# Since we are adding it to the raw scores before the softmax, this is
|
# Since we are adding it to the raw scores before the softmax, this is
|
||||||
# effectively the same as removing these entirely.
|
# effectively the same as removing these entirely.
|
||||||
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
||||||
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
|
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
|
||||||
|
|
||||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||||
@ -817,7 +1053,8 @@ class GPT2Model(GPT2PreTrainedModel):
|
|||||||
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
||||||
if encoder_attention_mask is None:
|
if encoder_attention_mask is None:
|
||||||
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
||||||
encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
if self._attn_implementation != "flash_attention_2":
|
||||||
|
encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
||||||
else:
|
else:
|
||||||
encoder_attention_mask = None
|
encoder_attention_mask = None
|
||||||
|
|
||||||
|
@ -19,8 +19,17 @@ import gc
|
|||||||
import math
|
import math
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from transformers import GPT2Config, is_torch_available
|
from transformers import GPT2Config, is_torch_available
|
||||||
from transformers.testing_utils import backend_empty_cache, require_torch, slow, torch_device
|
from transformers.testing_utils import (
|
||||||
|
backend_empty_cache,
|
||||||
|
require_flash_attn,
|
||||||
|
require_torch,
|
||||||
|
require_torch_gpu,
|
||||||
|
slow,
|
||||||
|
torch_device,
|
||||||
|
)
|
||||||
|
|
||||||
from ...generation.test_utils import GenerationTesterMixin
|
from ...generation.test_utils import GenerationTesterMixin
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
@ -858,3 +867,40 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
|
|||||||
"but said in a statement to The Associated Press that"
|
"but said in a statement to The Associated Press that"
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@require_flash_attn
|
||||||
|
@require_torch_gpu
|
||||||
|
@pytest.mark.flash_attn_test
|
||||||
|
@slow
|
||||||
|
def test_flash_attn_2_generate_padding_left(self):
|
||||||
|
"""
|
||||||
|
Overwritting the common test as the test is flaky on tiny models
|
||||||
|
"""
|
||||||
|
model = GPT2LMHeadModel.from_pretrained("gpt2", torch_dtype=torch.float16).to(0)
|
||||||
|
|
||||||
|
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
||||||
|
|
||||||
|
texts = ["hi", "Hello this is a very long sentence"]
|
||||||
|
|
||||||
|
tokenizer.padding_side = "left"
|
||||||
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
|
|
||||||
|
inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0)
|
||||||
|
|
||||||
|
output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||||
|
output_native = tokenizer.batch_decode(output_native)
|
||||||
|
|
||||||
|
model = GPT2LMHeadModel.from_pretrained(
|
||||||
|
"gpt2", device_map={"": 0}, attn_implementation="flash_attention_2", torch_dtype=torch.float16
|
||||||
|
)
|
||||||
|
|
||||||
|
output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||||
|
output_fa_2 = tokenizer.batch_decode(output_fa_2)
|
||||||
|
|
||||||
|
expected_output = [
|
||||||
|
"<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>hi, who was born in the city of Kolkata, was a member of the Kolkata",
|
||||||
|
"Hello this is a very long sentence. I'm sorry. I'm sorry. I'm sorry. I'm sorry. I'm sorry",
|
||||||
|
]
|
||||||
|
|
||||||
|
self.assertListEqual(output_native, output_fa_2)
|
||||||
|
self.assertListEqual(output_native, expected_output)
|
||||||
|
Loading…
Reference in New Issue
Block a user