mirror of https://github.com/huggingface/transformers.git
Add Flash Attention 2 support to Bark (#27364)
* change handmade attention mask to _prepare_4d_attention_mask
* add flash attention 2 support in Bark
* add flash attention 2 tests on BarkSemanticModel
* make style
* fix flash attention and tests + make style
* fix memory leak and allow Bark to pass flash attention to sub-models
* make style
* Apply suggestions from code review
* remove unnecessary code from tests + justify overriding
* Update tests/models/bark/test_modeling_bark.py
* make style

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent ef71673616 · commit a5bee89c9d
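In short, the commit wires Bark and its sub-models (semantic, coarse, fine) into the shared Flash Attention 2 machinery. A minimal usage sketch, not part of the diff (it assumes a CUDA device, the flash-attn package installed, and the "suno/bark-small" checkpoint; the `use_flash_attention_2` loading flag is the one exercised by the new tests below):

    import torch
    from transformers import BarkModel

    # Sketch: load Bark with Flash Attention 2 enabled by this commit.
    # FA2 requires half precision and a GPU; "suno/bark-small" is an assumed checkpoint.
    model = BarkModel.from_pretrained(
        "suno/bark-small",
        torch_dtype=torch.float16,
        use_flash_attention_2=True,
    ).to("cuda")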
src/transformers/models/bark/modeling_bark.py:

@@ -26,12 +26,14 @@ from ...generation.logits_process import (
     BarkEosPrioritizerLogitsProcessor,
     SuppressTokensLogitsProcessor,
 )
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
 from ...modeling_outputs import CausalLMOutputWithPast, MaskedLMOutput
 from ...modeling_utils import PreTrainedModel, get_parameter_device
 from ...utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     is_accelerate_available,
+    is_flash_attn_2_available,
     logging,
 )
 from ..auto import AutoModel
@@ -49,6 +51,11 @@ from .generation_configuration_bark import (
 )
 
 
+if is_flash_attn_2_available():
+    from flash_attn import flash_attn_func, flash_attn_varlen_func
+    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
+
+
 logger = logging.get_logger(__name__)
 
 
@@ -62,6 +69,19 @@ BARK_PRETRAINED_MODEL_ARCHIVE_LIST = [
 ]
 
 
+# Copied from transformers.models.llama.modeling_llama._get_unpad_data
+def _get_unpad_data(attention_mask):
+    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+    max_seqlen_in_batch = seqlens_in_batch.max().item()
+    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+    return (
+        indices,
+        cu_seqlens,
+        max_seqlen_in_batch,
+    )
+
+
 class BarkSelfAttention(nn.Module):
     # adapted from GPTNeoSelfAttention and Bark code
     # BarkSelfAttention can have two attention types, i.e. full attention or causal attention
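A worked toy illustration of what `_get_unpad_data` returns, useful when reading the varlen call further down (a sketch with a made-up 2x4 padding mask, not part of the diff; `_get_unpad_data` is the function added just above):

    import torch

    mask = torch.tensor([[1, 1, 1, 0],   # sequence length 3
                         [1, 1, 0, 0]])  # sequence length 2
    indices, cu_seqlens, max_seqlen = _get_unpad_data(mask)
    # indices of non-padding tokens in the flattened batch: tensor([0, 1, 2, 4, 5])
    # cu_seqlens (cumulative lengths with a leading 0):      tensor([0, 3, 5])
    # max_seqlen_in_batch:                                   3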
@@ -187,6 +207,177 @@ class BarkSelfAttention(nn.Module):
         return outputs
 
 
+class BarkSelfFlashAttention2(BarkSelfAttention):
+    """
+    Bark flash attention module. This module inherits from `BarkSelfAttention` as the weights of the module stay
+    untouched. The only required change is in the forward pass, where it needs to correctly call the public API of
+    flash attention and deal with padding tokens in case the input contains any of them.
+    """
+
+    def _split_heads(self, tensor, num_heads, attn_head_size):
+        """
+        Splits hidden_size dim into attn_head_size and num_heads
+        """
+        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
+        tensor = tensor.view(new_shape)
+        # Flash attention requires the input to have the shape
+        # batch_size x seq_length x head_dim x hidden_dim - (batch, seq_length, head, head_features)
+        return tensor
+
+    def _merge_heads(self, tensor, num_heads, attn_head_size):
+        """
+        Merges attn_head_size dim and num_attn_heads dim into hidden_size
+        """
+        # re-assemble all head outputs side by side
+        # (batch, seq_len, num_heads, attn_head_size) -> (batch, seq_len, num_heads*attn_head_size)
+        tensor = tensor.view(tensor.size()[:-2] + (num_heads * attn_head_size,))
+        return tensor
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask=None,
+        past_key_values=None,
+        head_mask=None,
+        use_cache=False,
+        output_attentions=False,
+    ):
+        batch_size, query_len, _ = hidden_states.size()
+
+        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+        query, key, value = self.att_proj(hidden_states).split(self.embed_dim, dim=2)
+
+        query = self._split_heads(query, self.num_heads, self.head_dim)
+        key = self._split_heads(key, self.num_heads, self.head_dim)
+        value = self._split_heads(value, self.num_heads, self.head_dim)
+
+        if past_key_values is not None:
+            # (batch, head, seq_length, head_features) -> (batch, seq_length, head, head_features)
+            past_key = past_key_values[0].transpose(1, 2)
+            past_value = past_key_values[1].transpose(1, 2)
+            # and merge on seq_length
+            key = torch.cat((past_key, key), dim=1)
+            value = torch.cat((past_value, value), dim=1)
+
+        if use_cache is True:
+            # (batch, head, seq_length, head_features)
+            present = (key.transpose(1, 2), value.transpose(1, 2))
+        else:
+            present = None
+
+        attn_output = self._flash_attention_forward(query, key, value, attention_mask, query_len, dropout=self.dropout)
+
+        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
+        attn_output = self.out_proj(attn_output)
+        attn_output = self.resid_dropout(attn_output)
+
+        outputs = (attn_output, present)
+        if output_attentions:
+            attn_weights = None
+            outputs += (attn_weights,)
+
+        return outputs
+
+    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
+    def _flash_attention_forward(
+        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
+    ):
+        """
+        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
+        first unpad the input, then compute the attention scores and pad the final attention scores.
+
+        Args:
+            query_states (`torch.Tensor`):
+                Input query states to be passed to Flash Attention API
+            key_states (`torch.Tensor`):
+                Input key states to be passed to Flash Attention API
+            value_states (`torch.Tensor`):
+                Input value states to be passed to Flash Attention API
+            attention_mask (`torch.Tensor`):
+                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+                position of padding tokens and 1 for the position of non-padding tokens.
+            dropout (`float`, *optional*):
+                Attention dropout
+            softmax_scale (`float`, *optional*):
+                The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim)
+        """
+        # Contains at least one padding token in the sequence
+        if attention_mask is not None:
+            batch_size = query_states.shape[0]
+            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
+                query_states, key_states, value_states, attention_mask, query_length
+            )
+
+            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+
+            attn_output_unpad = flash_attn_varlen_func(
+                query_states,
+                key_states,
+                value_states,
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k,
+                max_seqlen_q=max_seqlen_in_batch_q,
+                max_seqlen_k=max_seqlen_in_batch_k,
+                dropout_p=dropout,
+                softmax_scale=softmax_scale,
+                causal=self.is_causal,
+            )
+
+            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
+        else:
+            attn_output = flash_attn_func(
+                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal
+            )
+
+        return attn_output
+
+    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
+    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
+        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
+        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
+
+        key_layer = index_first_axis(
+            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+        )
+        value_layer = index_first_axis(
+            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+        )
+        if query_length == kv_seq_len:
+            query_layer = index_first_axis(
+                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
+            )
+            cu_seqlens_q = cu_seqlens_k
+            max_seqlen_in_batch_q = max_seqlen_in_batch_k
+            indices_q = indices_k
+        elif query_length == 1:
+            max_seqlen_in_batch_q = 1
+            cu_seqlens_q = torch.arange(
+                batch_size + 1, dtype=torch.int32, device=query_layer.device
+            )  # There is a memcpy here, that is very bad.
+            indices_q = cu_seqlens_q[:-1]
+            query_layer = query_layer.squeeze(1)
+        else:
+            # The -q_len: slice assumes left padding.
+            attention_mask = attention_mask[:, -query_length:]
+            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
+
+        return (
+            query_layer,
+            key_layer,
+            value_layer,
+            indices_q,
+            (cu_seqlens_q, cu_seqlens_k),
+            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
+        )
+
+
+BARK_ATTENTION_CLASSES = {
+    "default": BarkSelfAttention,
+    "flash_attention_2": BarkSelfFlashAttention2,
+}
+
+
 class BarkLayerNorm(nn.Module):
     """LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False."""
 
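One point worth spelling out from the class above: the overridden `_split_heads`/`_merge_heads` deliberately skip the transpose the base class performs, because the flash-attn kernels expect a `(batch, seq_len, num_heads, head_dim)` layout rather than the eager `(batch, num_heads, seq_len, head_dim)` one. A shape-only sketch with illustrative values:

    import torch

    hidden = torch.randn(2, 16, 512)   # (batch, seq_len, hidden_size)
    q = hidden.view(2, 16, 8, 64)      # FA2 layout: (batch, seq, heads, head_dim)
    q_eager = q.transpose(1, 2)        # eager layout: (batch, heads, seq, head_dim)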
@@ -229,7 +420,8 @@ class BarkBlock(nn.Module):
             self.layernorm_1 = nn.LayerNorm(config.hidden_size)
             self.layernorm_2 = nn.LayerNorm(config.hidden_size)
 
-        self.attn = BarkSelfAttention(config, is_causal=is_causal)
+        attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
+        self.attn = BARK_ATTENTION_CLASSES[attn_type](config, is_causal=is_causal)
 
         self.mlp = BarkMLP(config)
 
@@ -277,6 +469,7 @@ class BarkPreTrainedModel(PreTrainedModel):
 
     config_class = BarkConfig
     supports_gradient_checkpointing = False
+    _supports_flash_attn_2 = True
 
     def _init_weights(self, module):
         """Initialize the weights."""
@@ -596,21 +789,13 @@ class BarkCausalModel(BarkPreTrainedModel):
         if attention_mask is not None:
             if batch_size <= 0:
                 raise ValueError("batch_size has to be defined and > 0")
-            attention_mask = attention_mask.view(batch_size, -1)
-            # We create a 3D attention mask from a 2D tensor mask.
-            # Sizes are [batch_size, 1, 1, to_seq_length]
-            # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
-            # this attention mask is more simple than the triangular masking of causal attention
-            # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
-            attention_mask = attention_mask[:, None, None, :]
-
-            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
-            # masked positions, this operation will create a tensor which is 0.0 for
-            # positions we want to attend and the dtype's smallest value for masked positions.
-            # Since we are adding it to the raw scores before the softmax, this is
-            # effectively the same as removing these entirely.
-            attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
-            attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
+            if getattr(self.config, "_flash_attn_2_enabled", False):
+                attention_mask = attention_mask if 0 in attention_mask else None
+            else:
+                attention_mask = attention_mask.view(batch_size, -1)
+                # [bsz, to_seq_length] -> [bsz, 1, 1, to_seq_length]
+                # from_seq_length is 1 to easily broadcast
+                attention_mask = _prepare_4d_attention_mask(attention_mask, input_embeds.dtype, tgt_len=1)
 
         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
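For reference, the helper that replaces the handmade mask produces the same additive mask. A sketch (relying on the documented behavior of `_prepare_4d_attention_mask`, which expands a `[bsz, src_len]` padding mask into an inverted `[bsz, 1, tgt_len, src_len]` float mask):

    import torch
    from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask

    mask = torch.tensor([[1, 1, 0]])  # one padded position
    mask_4d = _prepare_4d_attention_mask(mask, torch.float32, tgt_len=1)
    # mask_4d.shape == (1, 1, 1, 3): 0.0 where attended,
    # torch.finfo(torch.float32).min where padded - the same result as the
    # removed (1.0 - attention_mask) * torch.finfo(self.dtype).min computation.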
@@ -1233,10 +1418,12 @@ class BarkFineModel(BarkPreTrainedModel):
         if attention_mask is not None:
             if batch_size <= 0:
                 raise ValueError("batch_size has to be defined and > 0")
-            attention_mask = attention_mask.view(batch_size, -1)
-            attention_mask = attention_mask[:, None, None, :]
-            attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
-            attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
+            if getattr(self.config, "_flash_attn_2_enabled", False):
+                attention_mask = attention_mask if 0 in attention_mask else None
+            else:
+                # [bsz, to_seq_length] -> [bsz, 1, 1, to_seq_length]
+                # from_seq_length is 1 to easily broadcast
+                attention_mask = _prepare_4d_attention_mask(attention_mask, input_embeds.dtype, tgt_len=1)
 
         head_mask = self.get_head_mask(head_mask, self.config.num_layers)
 
@@ -1669,3 +1856,32 @@ class BarkModel(BarkPreTrainedModel):
             return audio, output_lengths
 
         return audio
+
+    @classmethod
+    def _check_and_enable_flash_attn_2(
+        cls, config, torch_dtype: Optional[torch.dtype] = None, device_map: Optional[Union[str, Dict[str, int]]] = None
+    ):
+        """
+        `_check_and_enable_flash_attn_2` originally doesn't expand flash attention enabling to the model
+        sub-configurations. We override the original method to make sure that Bark sub-models use Flash Attention
+        if necessary.
+
+        If you don't know about Flash Attention, check out the official repository of flash attention:
+        https://github.com/Dao-AILab/flash-attention
+
+        For using Flash Attention 1.0 you can do it directly via the `BetterTransformer` API, have a look at this
+        specific section of the documentation to learn more about it:
+        https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#decoder-models
+
+        The method checks if the current setup is compatible with Flash Attention as it requires the model to be in
+        half precision and not run on CPU.
+
+        If all checks pass, the method will create an attribute in the config `_flash_attn_2_enabled` so that the
+        model can initialize the correct attention module.
+        """
+        config = super()._check_and_enable_flash_attn_2(config, torch_dtype, device_map)
+
+        config.semantic_config._flash_attn_2_enabled = getattr(config, "_flash_attn_2_enabled", False)
+        config.coarse_acoustics_config._flash_attn_2_enabled = getattr(config, "_flash_attn_2_enabled", False)
+        config.fine_acoustics_config._flash_attn_2_enabled = getattr(config, "_flash_attn_2_enabled", False)
+        return config
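The effect of the override is easy to check on a loaded model (a sketch, assuming the same checkpoint and flag as in the example at the top; `_flash_attn_2_enabled` is the private attribute set by the base implementation and propagated here):

    model = BarkModel.from_pretrained(
        "suno/bark-small", torch_dtype=torch.float16, use_flash_attention_2=True
    )
    # without the override, only the top-level config would carry the flag
    assert model.config.semantic_config._flash_attn_2_enabled
    assert model.config.coarse_acoustics_config._flash_attn_2_enabled
    assert model.config.fine_acoustics_config._flash_attn_2_enabled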
tests/models/bark/test_modeling_bark.py:

@@ -20,6 +20,8 @@ import inspect
 import tempfile
 import unittest
 
+from pytest import mark
+
 from transformers import (
     BarkCoarseConfig,
     BarkConfig,
@@ -33,6 +35,7 @@ from transformers.models.bark.generation_configuration_bark import (
     BarkSemanticGenerationConfig,
 )
 from transformers.testing_utils import (
+    require_flash_attn,
     require_torch,
     require_torch_fp16,
     require_torch_gpu,
@@ -872,6 +875,122 @@ class BarkFineModelTest(ModelTesterMixin, unittest.TestCase):
         # Check that the model can still do a forward pass successfully (every parameter should be resized)
         model(**self._prepare_for_class(inputs_dict, model_class))
 
+    @require_flash_attn
+    @require_torch_gpu
+    @mark.flash_attn_test
+    @slow
+    def test_flash_attn_2_inference(self):
+        for model_class in self.all_model_classes:
+            if not model_class._supports_flash_attn_2:
+                return
+
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+            model = model_class(config)
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                model.save_pretrained(tmpdirname)
+                model_fa = model_class.from_pretrained(
+                    tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
+                )
+                model_fa.to(torch_device)
+
+                model = model_class.from_pretrained(
+                    tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
+                )
+                model.to(torch_device)
+
+                dummy_input = inputs_dict["input_ids"][:1]
+                if dummy_input.dtype in [torch.float32, torch.float16]:
+                    dummy_input = dummy_input.to(torch.bfloat16)
+
+                dummy_attention_mask = inputs_dict.get("attention_mask", None)
+
+                if dummy_attention_mask is not None:
+                    dummy_attention_mask = dummy_attention_mask[:1]
+                    dummy_attention_mask[:, 1:] = 1
+                    dummy_attention_mask[:, :1] = 0
+
+                outputs = model(inputs_dict["codebook_idx"], dummy_input, output_hidden_states=True)
+                outputs_fa = model_fa(inputs_dict["codebook_idx"], dummy_input, output_hidden_states=True)
+
+                logits = outputs.hidden_states[-1]
+                logits_fa = outputs_fa.hidden_states[-1]
+
+                assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
+
+                other_inputs = {"output_hidden_states": True}
+                if dummy_attention_mask is not None:
+                    other_inputs["attention_mask"] = dummy_attention_mask
+
+                outputs = model(inputs_dict["codebook_idx"], dummy_input, **other_inputs)
+                outputs_fa = model_fa(inputs_dict["codebook_idx"], dummy_input, **other_inputs)
+
+                logits = outputs.hidden_states[-1]
+                logits_fa = outputs_fa.hidden_states[-1]
+
+                assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2)
+
+                # check with inference + dropout
+                model.train()
+                _ = model_fa(inputs_dict["codebook_idx"], dummy_input, **other_inputs)
+
+    @require_flash_attn
+    @require_torch_gpu
+    @mark.flash_attn_test
+    @slow
+    def test_flash_attn_2_inference_padding_right(self):
+        for model_class in self.all_model_classes:
+            if not model_class._supports_flash_attn_2:
+                return
+
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+            model = model_class(config)
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                model.save_pretrained(tmpdirname)
+                model_fa = model_class.from_pretrained(
+                    tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
+                )
+                model_fa.to(torch_device)
+
+                model = model_class.from_pretrained(
+                    tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
+                )
+                model.to(torch_device)
+
+                dummy_input = inputs_dict["input_ids"][:1]
+                if dummy_input.dtype in [torch.float32, torch.float16]:
+                    dummy_input = dummy_input.to(torch.bfloat16)
+
+                dummy_attention_mask = inputs_dict.get("attention_mask", None)
+
+                if dummy_attention_mask is not None:
+                    dummy_attention_mask = dummy_attention_mask[:1]
+                    dummy_attention_mask[:, :-1] = 1
+                    dummy_attention_mask[:, -1:] = 0
+
+                outputs = model(inputs_dict["codebook_idx"], dummy_input, output_hidden_states=True)
+                outputs_fa = model_fa(inputs_dict["codebook_idx"], dummy_input, output_hidden_states=True)
+
+                logits = outputs.hidden_states[-1]
+                logits_fa = outputs_fa.hidden_states[-1]
+
+                assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
+
+                other_inputs = {
+                    "output_hidden_states": True,
+                }
+                if dummy_attention_mask is not None:
+                    other_inputs["attention_mask"] = dummy_attention_mask
+
+                outputs = model(inputs_dict["codebook_idx"], dummy_input, **other_inputs)
+                outputs_fa = model_fa(inputs_dict["codebook_idx"], dummy_input, **other_inputs)
+
+                logits = outputs.hidden_states[-1]
+                logits_fa = outputs_fa.hidden_states[-1]
+
+                assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
+
 
 @require_torch
 class BarkModelIntegrationTests(unittest.TestCase):
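A note on the sliced asserts in the tests above: flash attention drops fully padded positions before the kernel runs and re-pads the output with zeros, so hidden states at padded positions are not expected to match the eager path. The tests therefore compare only the non-padded slice: `logits_fa[1:]` under left padding and `logits_fa[:-1]` under right padding.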