Add Flash Attention 2 support to Bark (#27364)

* change handmade attention mask to _prepare_4d_attention_mask

* add flashattention2 support in Bark

* add flashattention2 tests on BarkSemanticModel

* make style

* fix flashattention and tests + make style

* fix memory leak and allow Bark to pass flash attention to sub-models

* make style

* Apply suggestions from code review

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* remove unnecessary code from tests + justify overriding

* Update tests/models/bark/test_modeling_bark.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* make style

---------

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Yoach Lacombe, 2023-11-08 17:06:35 +00:00, committed by GitHub
commit a5bee89c9d (parent ef71673616)
2 changed files with 355 additions and 20 deletions
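
From a user's point of view the feature is switched on through `from_pretrained`, exactly as for the other models with Flash Attention 2 support. A minimal usage sketch (assumes a CUDA device, the flash-attn package installed and the suno/bark-small checkpoint; `use_flash_attention_2` and the half-precision dtype are the knobs exercised by the tests in this commit):

import torch
from transformers import AutoProcessor, BarkModel

# Load Bark in half precision with Flash Attention 2 enabled; the flag is
# propagated to the semantic, coarse and fine sub-models (see the
# _check_and_enable_flash_attn_2 override at the end of modeling_bark.py).
processor = AutoProcessor.from_pretrained("suno/bark-small")
model = BarkModel.from_pretrained(
    "suno/bark-small", torch_dtype=torch.float16, use_flash_attention_2=True
).to("cuda")

inputs = processor("Hello, my dog is cute", return_tensors="pt").to("cuda")
audio = model.generate(**inputs)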

src/transformers/models/bark/modeling_bark.py

@@ -26,12 +26,14 @@ from ...generation.logits_process import (
BarkEosPrioritizerLogitsProcessor,
SuppressTokensLogitsProcessor,
)
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_outputs import CausalLMOutputWithPast, MaskedLMOutput
from ...modeling_utils import PreTrainedModel, get_parameter_device
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_accelerate_available,
is_flash_attn_2_available,
logging,
)
from ..auto import AutoModel
@@ -49,6 +51,11 @@ from .generation_configuration_bark import (
)
if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
logger = logging.get_logger(__name__)
@@ -62,6 +69,19 @@ BARK_PRETRAINED_MODEL_ARCHIVE_LIST = [
]
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(attention_mask):
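# Given a 2D padding mask of shape (batch_size, seq_len) with 1 for real tokens and 0 for
# padding, return the flat indices of the non-padding tokens, the cumulative per-sequence
# lengths expected by flash_attn_varlen_func, and the longest unpadded sequence length in
# the batch.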
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
class BarkSelfAttention(nn.Module):
# adapted from GPTNeoSelfAttention and Bark code
# BarkSelfAttention can have two attention types, i.e. full attention or causal attention
@@ -187,6 +207,177 @@ class BarkSelfAttention(nn.Module):
return outputs
class BarkSelfFlashAttention2(BarkSelfAttention):
"""
Bark flash attention module. This module inherits from `BarkSelfAttention` as the weights of the module stay
untouched. The only required change is in the forward pass, where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
def _split_heads(self, tensor, num_heads, attn_head_size):
"""
Splits hidden_size dim into attn_head_size and num_heads
"""
new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
tensor = tensor.view(new_shape)
# Flash attention requires the input to have the shape
# batch_size x seq_length x num_heads x head_dim - (batch, seq_length, num_heads, head_features)
return tensor
def _merge_heads(self, tensor, num_heads, attn_head_size):
"""
Merges attn_head_size dim and num_attn_heads dim into hidden_size
"""
# re-assemble all head outputs side by side
# (batch, seq_len, num_heads, attn_head_size) -> (batch, seq_len, num_heads*attn_head_size)
tensor = tensor.view(tensor.size()[:-2] + (num_heads * attn_head_size,))
return tensor
def forward(
self,
hidden_states,
attention_mask=None,
past_key_values=None,
head_mask=None,
use_cache=False,
output_attentions=False,
):
batch_size, query_len, _ = hidden_states.size()
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
query, key, value = self.att_proj(hidden_states).split(self.embed_dim, dim=2)
query = self._split_heads(query, self.num_heads, self.head_dim)
key = self._split_heads(key, self.num_heads, self.head_dim)
value = self._split_heads(value, self.num_heads, self.head_dim)
if past_key_values is not None:
# (batch, head, seq_length, head_features) -> (batch, seq_length, head, head_features)
past_key = past_key_values[0].transpose(1, 2)
past_value = past_key_values[1].transpose(1, 2)
# and merge on seq_length
key = torch.cat((past_key, key), dim=1)
value = torch.cat((past_value, value), dim=1)
if use_cache is True:
# (batch, head, seq_length, head_features)
present = (key.transpose(1, 2), value.transpose(1, 2))
else:
present = None
attn_output = self._flash_attention_forward(query, key, value, attention_mask, query_len, dropout=self.dropout)
attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
attn_output = self.out_proj(attn_output)
attn_output = self.resid_dropout(attn_output)
outputs = (attn_output, present)
if output_attentions:
attn_weights = None
outputs += (attn_weights,)
return outputs
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
def _flash_attention_forward(
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
it first unpads the input, then computes the attention scores and pads the final attention scores.
Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key_states (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value_states (`torch.Tensor`):
Input value states to be passed to Flash Attention API
attention_mask (`torch.Tensor`):
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
position of padding tokens and 1 for the position of non-padding tokens.
dropout (`float`, *optional*):
Attention dropout probability
softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim)
"""
# Contains at least one padding token in the sequence
if attention_mask is not None:
batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
query_states, key_states, value_states, attention_mask, query_length
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=self.is_causal,
)
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal
)
return attn_output
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
key_layer = index_first_axis(
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
value_layer = index_first_axis(
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
if query_length == kv_seq_len:
query_layer = index_first_axis(
query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1, dtype=torch.int32, device=query_layer.device
) # There is a memcpy here, that is very bad.
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
# The -q_len: slice assumes left padding.
attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
return (
query_layer,
key_layer,
value_layer,
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
BARK_ATTENTION_CLASSES = {
"default": BarkSelfAttention,
"flash_attention_2": BarkSelfFlashAttention2,
}
class BarkLayerNorm(nn.Module):
"""LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False."""
@@ -229,7 +420,8 @@ class BarkBlock(nn.Module):
self.layernorm_1 = nn.LayerNorm(config.hidden_size)
self.layernorm_2 = nn.LayerNorm(config.hidden_size)
self.attn = BarkSelfAttention(config, is_causal=is_causal)
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.attn = BARK_ATTENTION_CLASSES[attn_type](config, is_causal=is_causal)
self.mlp = BarkMLP(config)
@@ -277,6 +469,7 @@ class BarkPreTrainedModel(PreTrainedModel):
config_class = BarkConfig
supports_gradient_checkpointing = False
_supports_flash_attn_2 = True
def _init_weights(self, module):
"""Initialize the weights."""
@@ -596,21 +789,13 @@ class BarkCausalModel(BarkPreTrainedModel):
if attention_mask is not None:
if batch_size <= 0:
raise ValueError("batch_size has to be defined and > 0")
attention_mask = attention_mask.view(batch_size, -1)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
attention_mask = attention_mask[:, None, None, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and the dtype's smallest value for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
if getattr(self.config, "_flash_attn_2_enabled", False):
attention_mask = attention_mask if 0 in attention_mask else None
else:
attention_mask = attention_mask.view(batch_size, -1)
# [bsz, to_seq_length] -> [bsz, 1, 1, to_seq_length]
# from_seq_length is 1 to easily broadcast
attention_mask = _prepare_4d_attention_mask(attention_mask, input_embeds.dtype, tgt_len=1)
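# _prepare_4d_attention_mask expands the 2D padding mask to (batch_size, 1, tgt_len, src_len)
# and fills masked positions with the dtype's minimum value, replacing the handmade expansion
# that this commit removes above.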
# Prepare head mask if needed
# 1.0 in head_mask indicates we keep the head
@@ -1233,10 +1418,12 @@ class BarkFineModel(BarkPreTrainedModel):
if attention_mask is not None:
if batch_size <= 0:
raise ValueError("batch_size has to be defined and > 0")
attention_mask = attention_mask.view(batch_size, -1)
attention_mask = attention_mask[:, None, None, :]
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
if getattr(self.config, "_flash_attn_2_enabled", False):
attention_mask = attention_mask if 0 in attention_mask else None
else:
# [bsz, to_seq_length] -> [bsz, 1, 1, to_seq_length]
# from_seq_length is 1 to easily broadcast
attention_mask = _prepare_4d_attention_mask(attention_mask, input_embeds.dtype, tgt_len=1)
head_mask = self.get_head_mask(head_mask, self.config.num_layers)
@@ -1669,3 +1856,32 @@ class BarkModel(BarkPreTrainedModel):
return audio, output_lengths
return audio
@classmethod
def _check_and_enable_flash_attn_2(
cls, config, torch_dtype: Optional[torch.dtype] = None, device_map: Optional[Union[str, Dict[str, int]]] = None
):
"""
`_check_and_enable_flash_attn_2` originally don't expand flash attention enabling to the model
sub-configurations. We override the original method to make sure that Bark sub-models are using Flash Attention
if necessary.
If you don't know about Flash Attention, check out the official repository of flash attention:
https://github.com/Dao-AILab/flash-attention
For using Flash Attention 1.0 you can do it directly via the `BetterTransformer` API, have a look at this
specific section of the documentation to learn more about it:
https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#decoder-models
The method checks if the current setup is compatible with Flash Attention as it requires the model to be in
half precision and not ran on CPU.
If all checks pass, the method will create an attribute in the config `_flash_attn_2_enabled` so that the model
can initialize the correct attention module
"""
config = super()._check_and_enable_flash_attn_2(config, torch_dtype, device_map)
config.semantic_config._flash_attn_2_enabled = getattr(config, "_flash_attn_2_enabled", False)
config.coarse_acoustics_config._flash_attn_2_enabled = getattr(config, "_flash_attn_2_enabled", False)
config.fine_acoustics_config._flash_attn_2_enabled = getattr(config, "_flash_attn_2_enabled", False)
return config
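
With this override in place, enabling Flash Attention 2 on the composite model also flags every sub-configuration, which in turn makes each sub-model's `BarkBlock` pick `BarkSelfFlashAttention2` from `BARK_ATTENTION_CLASSES`. A small sketch of the user-visible effect (assumes a CUDA device and the flash-attn package installed; the checkpoint name is illustrative):

import torch
from transformers import BarkModel

model = BarkModel.from_pretrained(
    "suno/bark-small", torch_dtype=torch.float16, use_flash_attention_2=True
)

# The override above copied the flag from the top-level config onto each
# sub-model configuration before the sub-models were instantiated.
for sub_config in (
    model.config.semantic_config,
    model.config.coarse_acoustics_config,
    model.config.fine_acoustics_config,
):
    assert getattr(sub_config, "_flash_attn_2_enabled", False)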

tests/models/bark/test_modeling_bark.py

@@ -20,6 +20,8 @@ import inspect
import tempfile
import unittest
from pytest import mark
from transformers import (
BarkCoarseConfig,
BarkConfig,
@@ -33,6 +35,7 @@ from transformers.models.bark.generation_configuration_bark import (
BarkSemanticGenerationConfig,
)
from transformers.testing_utils import (
require_flash_attn,
require_torch,
require_torch_fp16,
require_torch_gpu,
@@ -872,6 +875,122 @@ class BarkFineModelTest(ModelTesterMixin, unittest.TestCase):
# Check that the model can still do a forward pass successfully (every parameter should be resized)
model(**self._prepare_for_class(inputs_dict, model_class))
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
def test_flash_attn_2_inference(self):
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
return
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
)
model_fa.to(torch_device)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
)
model.to(torch_device)
dummy_input = inputs_dict["input_ids"][:1]
if dummy_input.dtype in [torch.float32, torch.float16]:
dummy_input = dummy_input.to(torch.bfloat16)
dummy_attention_mask = inputs_dict.get("attention_mask", None)
if dummy_attention_mask is not None:
dummy_attention_mask = dummy_attention_mask[:1]
dummy_attention_mask[:, 1:] = 1
dummy_attention_mask[:, :1] = 0
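# only the first position is masked, i.e. the input is effectively left-padded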
outputs = model(inputs_dict["codebook_idx"], dummy_input, output_hidden_states=True)
outputs_fa = model_fa(inputs_dict["codebook_idx"], dummy_input, output_hidden_states=True)
logits = outputs.hidden_states[-1]
logits_fa = outputs_fa.hidden_states[-1]
assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
other_inputs = {"output_hidden_states": True}
if dummy_attention_mask is not None:
other_inputs["attention_mask"] = dummy_attention_mask
outputs = model(inputs_dict["codebook_idx"], dummy_input, **other_inputs)
outputs_fa = model_fa(inputs_dict["codebook_idx"], dummy_input, **other_inputs)
logits = outputs.hidden_states[-1]
logits_fa = outputs_fa.hidden_states[-1]
assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2)
# check with inference + dropout
model.train()
_ = model_fa(inputs_dict["codebook_idx"], dummy_input, **other_inputs)
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
def test_flash_attn_2_inference_padding_right(self):
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
return
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
)
model_fa.to(torch_device)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
)
model.to(torch_device)
dummy_input = inputs_dict["input_ids"][:1]
if dummy_input.dtype in [torch.float32, torch.float16]:
dummy_input = dummy_input.to(torch.bfloat16)
dummy_attention_mask = inputs_dict.get("attention_mask", None)
if dummy_attention_mask is not None:
dummy_attention_mask = dummy_attention_mask[:1]
dummy_attention_mask[:, :-1] = 1
dummy_attention_mask[:, -1:] = 0
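# only the last position is masked, i.e. the input is effectively right-padded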
outputs = model(inputs_dict["codebook_idx"], dummy_input, output_hidden_states=True)
outputs_fa = model_fa(inputs_dict["codebook_idx"], dummy_input, output_hidden_states=True)
logits = outputs.hidden_states[-1]
logits_fa = outputs_fa.hidden_states[-1]
assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
other_inputs = {
"output_hidden_states": True,
}
if dummy_attention_mask is not None:
other_inputs["attention_mask"] = dummy_attention_mask
outputs = model(inputs_dict["codebook_idx"], dummy_input, **other_inputs)
outputs_fa = model_fa(inputs_dict["codebook_idx"], dummy_input, **other_inputs)
logits = outputs.hidden_states[-1]
logits_fa = outputs_fa.hidden_states[-1]
assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
@require_torch
class BarkModelIntegrationTests(unittest.TestCase):
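
The new tests are gated behind the flash-attn and slow markers; on a CUDA machine with flash-attn installed they can be run with, for example, `RUN_SLOW=1 pytest tests/models/bark/test_modeling_bark.py -m flash_attn_test`.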