Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 18:22:34 +06:00)
Add Flash Attention 2 support to Bark (#27364)
* change handmade attention mask to _prepare_4d_attention_mask
* add flashattention2 support in Bark
* add flashattention2 tests on BarkSemanticModel
* make style
* fix flashattention and tests + make style
* fix memory leak and allow Bark to pass flash attention to sub-models
* make style
* Apply suggestions from code review
  Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
* remove unecessary code from tests + justify overriding
* Update tests/models/bark/test_modeling_bark.py
  Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* make style

---------

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent ef71673616
commit a5bee89c9d
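For readers who want to try the feature this commit adds, here is a minimal usage sketch (not part of the commit). It assumes the flash-attn package is installed and a CUDA GPU is available; "suno/bark-small" is an example checkpoint and the voice preset is illustrative, while the `use_flash_attention_2=True` keyword mirrors the one exercised in the new tests below.

import torch
from transformers import AutoProcessor, BarkModel

# Load Bark in half precision with Flash Attention 2 enabled (assumed checkpoint name).
processor = AutoProcessor.from_pretrained("suno/bark-small")
model = BarkModel.from_pretrained(
    "suno/bark-small",
    torch_dtype=torch.float16,  # Flash Attention 2 requires half precision
    use_flash_attention_2=True,
).to("cuda")

# Generate speech as usual; the flash-attention path is picked up transparently.
inputs = processor("Hello, my dog is cute", voice_preset="v2/en_speaker_6").to("cuda")
audio = model.generate(**inputs)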
--- a/src/transformers/models/bark/modeling_bark.py
+++ b/src/transformers/models/bark/modeling_bark.py
@@ -26,12 +26,14 @@ from ...generation.logits_process import (
     BarkEosPrioritizerLogitsProcessor,
     SuppressTokensLogitsProcessor,
 )
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
 from ...modeling_outputs import CausalLMOutputWithPast, MaskedLMOutput
 from ...modeling_utils import PreTrainedModel, get_parameter_device
 from ...utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     is_accelerate_available,
+    is_flash_attn_2_available,
     logging,
 )
 from ..auto import AutoModel
@@ -49,6 +51,11 @@ from .generation_configuration_bark import (
 )


+if is_flash_attn_2_available():
+    from flash_attn import flash_attn_func, flash_attn_varlen_func
+    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
+
+
 logger = logging.get_logger(__name__)


@@ -62,6 +69,19 @@ BARK_PRETRAINED_MODEL_ARCHIVE_LIST = [
 ]


+# Copied from transformers.models.llama.modeling_llama._get_unpad_data
+def _get_unpad_data(attention_mask):
+    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+    max_seqlen_in_batch = seqlens_in_batch.max().item()
+    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
+    return (
+        indices,
+        cu_seqlens,
+        max_seqlen_in_batch,
+    )
+
+
 class BarkSelfAttention(nn.Module):
     # adapted from GPTNeoSelfAttention and Bark code
     # BarkSelfAttention can have two attention type, i.e full attention or causal attention
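Editorial aside (not part of the patch): a tiny worked example of what the copied `_get_unpad_data` helper returns for a toy padding mask, with the values written out by hand.

import torch
import torch.nn.functional as F

# Two sequences, real lengths 3 and 2, right-padded to length 4 (1 = token, 0 = padding).
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]], dtype=torch.int32)

seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)              # tensor([3, 2])
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()   # tensor([0, 1, 2, 4, 5])
max_seqlen_in_batch = seqlens_in_batch.max().item()                           # 3
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))  # tensor([0, 3, 5])

# flash_attn_varlen_func later consumes the flattened non-padded tokens plus these
# cumulative sequence lengths, instead of a padded (batch, seq_len) layout.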
@@ -187,6 +207,177 @@ class BarkSelfAttention(nn.Module):
         return outputs


+class BarkSelfFlashAttention2(BarkSelfAttention):
+    """
+    Bark flash attention module. This module inherits from `BarkSelfAttention` as the weights of the module stays
+    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+    flash attention and deal with padding tokens in case the input contains any of them.
+    """
+
+    def _split_heads(self, tensor, num_heads, attn_head_size):
+        """
+        Splits hidden_size dim into attn_head_size and num_heads
+        """
+        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
+        tensor = tensor.view(new_shape)
+        # Flash attention requires the input to have the shape
+        # batch_size x seq_length x head_dim x hidden_dim - (batch, seq_length, head, head_features)
+        return tensor
+
+    def _merge_heads(self, tensor, num_heads, attn_head_size):
+        """
+        Merges attn_head_size dim and num_attn_heads dim into hidden_size
+        """
+        # re-assemble all head outputs side by side
+        # (batch, seq_len, num_heads, attn_head_size) -> (batch, seq_len, num_heads*attn_head_size)
+        tensor = tensor.view(tensor.size()[:-2] + (num_heads * attn_head_size,))
+        return tensor
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask=None,
+        past_key_values=None,
+        head_mask=None,
+        use_cache=False,
+        output_attentions=False,
+    ):
+        batch_size, query_len, _ = hidden_states.size()
+
+        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+        query, key, value = self.att_proj(hidden_states).split(self.embed_dim, dim=2)
+
+        query = self._split_heads(query, self.num_heads, self.head_dim)
+        key = self._split_heads(key, self.num_heads, self.head_dim)
+        value = self._split_heads(value, self.num_heads, self.head_dim)
+
+        if past_key_values is not None:
+            # (batch, head, seq_length, head_features) -> (batch, seq_length, head, head_features)
+            past_key = past_key_values[0].transpose(1, 2)
+            past_value = past_key_values[1].transpose(1, 2)
+            # and merge on seq_length
+            key = torch.cat((past_key, key), dim=1)
+            value = torch.cat((past_value, value), dim=1)
+
+        if use_cache is True:
+            # (batch, head, seq_length, head_features)
+            present = (key.transpose(1, 2), value.transpose(1, 2))
+        else:
+            present = None
+
+        attn_output = self._flash_attention_forward(query, key, value, attention_mask, query_len, dropout=self.dropout)
+
+        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
+        attn_output = self.out_proj(attn_output)
+        attn_output = self.resid_dropout(attn_output)
+
+        outputs = (attn_output, present)
+        if output_attentions:
+            attn_weights = None
+            outputs += (attn_weights,)
+
+        return outputs
+
+    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
+    def _flash_attention_forward(
+        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
+    ):
+        """
+        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
+        first unpad the input, then computes the attention scores and pad the final attention scores.
+
+        Args:
+            query_states (`torch.Tensor`):
+                Input query states to be passed to Flash Attention API
+            key_states (`torch.Tensor`):
+                Input key states to be passed to Flash Attention API
+            value_states (`torch.Tensor`):
+                Input value states to be passed to Flash Attention API
+            attention_mask (`torch.Tensor`):
+                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+                position of padding tokens and 1 for the position of non-padding tokens.
+            dropout (`int`, *optional*):
+                Attention dropout
+            softmax_scale (`float`, *optional*):
+                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
+        """
+        # Contains at least one padding token in the sequence
+        if attention_mask is not None:
+            batch_size = query_states.shape[0]
+            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
+                query_states, key_states, value_states, attention_mask, query_length
+            )
+
+            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+
+            attn_output_unpad = flash_attn_varlen_func(
+                query_states,
+                key_states,
+                value_states,
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k,
+                max_seqlen_q=max_seqlen_in_batch_q,
+                max_seqlen_k=max_seqlen_in_batch_k,
+                dropout_p=dropout,
+                softmax_scale=softmax_scale,
+                causal=self.is_causal,
+            )
+
+            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
+        else:
+            attn_output = flash_attn_func(
+                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal
+            )
+
+        return attn_output
+
+    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
+    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
+        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
+        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
+
+        key_layer = index_first_axis(
+            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+        )
+        value_layer = index_first_axis(
+            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+        )
+        if query_length == kv_seq_len:
+            query_layer = index_first_axis(
+                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
+            )
+            cu_seqlens_q = cu_seqlens_k
+            max_seqlen_in_batch_q = max_seqlen_in_batch_k
+            indices_q = indices_k
+        elif query_length == 1:
+            max_seqlen_in_batch_q = 1
+            cu_seqlens_q = torch.arange(
+                batch_size + 1, dtype=torch.int32, device=query_layer.device
+            )  # There is a memcpy here, that is very bad.
+            indices_q = cu_seqlens_q[:-1]
+            query_layer = query_layer.squeeze(1)
+        else:
+            # The -q_len: slice assumes left padding.
+            attention_mask = attention_mask[:, -query_length:]
+            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
+
+        return (
+            query_layer,
+            key_layer,
+            value_layer,
+            indices_q,
+            (cu_seqlens_q, cu_seqlens_k),
+            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
+        )
+
+
+BARK_ATTENTION_CLASSES = {
+    "default": BarkSelfAttention,
+    "flash_attention_2": BarkSelfFlashAttention2,
+}
+
+
 class BarkLayerNorm(nn.Module):
     """LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False."""

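Editorial aside (not part of the patch): the overridden `_split_heads`/`_merge_heads` above keep tensors in the `(batch, seq_len, num_heads, head_dim)` layout that the flash-attn kernels expect, whereas the default `BarkSelfAttention` moves heads next to the batch dimension; the cache (`present`) is still stored in the default layout, hence the `.transpose(1, 2)` calls around `past_key_values`. A shape-only sketch of that difference, under those assumptions:

import torch

batch, seq_len, num_heads, head_dim = 2, 5, 2, 4
hidden = torch.randn(batch, seq_len, num_heads * head_dim)

# Layout used by BarkSelfFlashAttention2._split_heads: reshape only, no transpose.
fa_layout = hidden.view(batch, seq_len, num_heads, head_dim)
print(fa_layout.shape)        # torch.Size([2, 5, 2, 4])

# Default attention layout: heads become the second dimension, as in the cache.
default_layout = fa_layout.transpose(1, 2)
print(default_layout.shape)   # torch.Size([2, 2, 5, 4])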
@@ -229,7 +420,8 @@ class BarkBlock(nn.Module):
             self.layernorm_1 = nn.LayerNorm(config.hidden_size)
             self.layernorm_2 = nn.LayerNorm(config.hidden_size)

-        self.attn = BarkSelfAttention(config, is_causal=is_causal)
+        attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
+        self.attn = BARK_ATTENTION_CLASSES[attn_type](config, is_causal=is_causal)

         self.mlp = BarkMLP(config)

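For illustration only (not part of the patch), the dispatch added to `BarkBlock.__init__` degrades gracefully for configs saved before this change, because `getattr` falls back to `False` when the private flag is absent. The stand-in classes below are placeholders, not real transformers configs:

class OldConfig:                   # saved before this PR: no `_flash_attn_2_enabled` attribute
    pass

class NewConfig:
    _flash_attn_2_enabled = True   # set by `_check_and_enable_flash_attn_2` at load time

for config in (OldConfig(), NewConfig()):
    attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
    print(type(config).__name__, "->", attn_type)
# OldConfig -> default
# NewConfig -> flash_attention_2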
@@ -277,6 +469,7 @@ class BarkPreTrainedModel(PreTrainedModel):

     config_class = BarkConfig
     supports_gradient_checkpointing = False
+    _supports_flash_attn_2 = True

     def _init_weights(self, module):
         """Initialize the weights."""
@@ -596,21 +789,13 @@ class BarkCausalModel(BarkPreTrainedModel):
         if attention_mask is not None:
             if batch_size <= 0:
                 raise ValueError("batch_size has to be defined and > 0")
-            attention_mask = attention_mask.view(batch_size, -1)
-            # We create a 3D attention mask from a 2D tensor mask.
-            # Sizes are [batch_size, 1, 1, to_seq_length]
-            # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
-            # this attention mask is more simple than the triangular masking of causal attention
-            # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
-            attention_mask = attention_mask[:, None, None, :]
-
-            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
-            # masked positions, this operation will create a tensor which is 0.0 for
-            # positions we want to attend and the dtype's smallest value for masked positions.
-            # Since we are adding it to the raw scores before the softmax, this is
-            # effectively the same as removing these entirely.
-            attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
-            attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
+            if getattr(self.config, "_flash_attn_2_enabled", False):
+                attention_mask = attention_mask if 0 in attention_mask else None
+            else:
+                attention_mask = attention_mask.view(batch_size, -1)
+                # [bsz, to_seq_length] -> [bsz, 1, 1, to_seq_length]
+                # from_seq_length is 1 to easily broadcast
+                attention_mask = _prepare_4d_attention_mask(attention_mask, input_embeds.dtype, tgt_len=1)

         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
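To make the mask change concrete (editorial aside, not part of the patch): the removed block built a broadcastable additive 4D mask by hand, which is essentially what `_prepare_4d_attention_mask(..., tgt_len=1)` now returns, while the flash-attention branch keeps the plain 2D padding mask, or drops it entirely when nothing is padded. A hand-rolled comparison under those assumptions:

import torch

dtype = torch.float32
attention_mask = torch.tensor([[1, 1, 0]])            # (batch, seq_len), 0 marks padding

# Old path (removed above), equivalent in spirit to _prepare_4d_attention_mask(..., tgt_len=1):
mask_4d = attention_mask[:, None, None, :].to(dtype)   # (batch, 1, 1, seq_len)
mask_4d = (1.0 - mask_4d) * torch.finfo(dtype).min     # 0.0 where attended, large negative where padded
print(mask_4d.shape)                                    # torch.Size([1, 1, 1, 3])

# New flash-attention path: keep the 2D mask only if some position is actually padded.
fa_mask = attention_mask if 0 in attention_mask else None
print(fa_mask)                                          # tensor([[1, 1, 0]])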
@@ -1233,10 +1418,12 @@ class BarkFineModel(BarkPreTrainedModel):
         if attention_mask is not None:
             if batch_size <= 0:
                 raise ValueError("batch_size has to be defined and > 0")
-            attention_mask = attention_mask.view(batch_size, -1)
-            attention_mask = attention_mask[:, None, None, :]
-            attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
-            attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
+            if getattr(self.config, "_flash_attn_2_enabled", False):
+                attention_mask = attention_mask if 0 in attention_mask else None
+            else:
+                # [bsz, to_seq_length] -> [bsz, 1, 1, to_seq_length]
+                # from_seq_length is 1 to easily broadcast
+                attention_mask = _prepare_4d_attention_mask(attention_mask, input_embeds.dtype, tgt_len=1)

         head_mask = self.get_head_mask(head_mask, self.config.num_layers)

@@ -1669,3 +1856,32 @@ class BarkModel(BarkPreTrainedModel):
             return audio, output_lengths

         return audio
+
+    @classmethod
+    def _check_and_enable_flash_attn_2(
+        cls, config, torch_dtype: Optional[torch.dtype] = None, device_map: Optional[Union[str, Dict[str, int]]] = None
+    ):
+        """
+        `_check_and_enable_flash_attn_2` originally don't expand flash attention enabling to the model
+        sub-configurations. We override the original method to make sure that Bark sub-models are using Flash Attention
+        if necessary.
+
+        If you don't know about Flash Attention, check out the official repository of flash attention:
+        https://github.com/Dao-AILab/flash-attention
+
+        For using Flash Attention 1.0 you can do it directly via the `BetterTransformer` API, have a look at this
+        specific section of the documentation to learn more about it:
+        https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#decoder-models
+
+        The method checks if the current setup is compatible with Flash Attention as it requires the model to be in
+        half precision and not ran on CPU.
+
+        If all checks pass, the method will create an attribute in the config `_flash_attn_2_enabled` so that the model
+        can initialize the correct attention module
+        """
+        config = super()._check_and_enable_flash_attn_2(config, torch_dtype, device_map)
+
+        config.semantic_config._flash_attn_2_enabled = getattr(config, "_flash_attn_2_enabled", False)
+        config.coarse_acoustics_config._flash_attn_2_enabled = getattr(config, "_flash_attn_2_enabled", False)
+        config.fine_acoustics_config._flash_attn_2_enabled = getattr(config, "_flash_attn_2_enabled", False)
+        return config
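A quick way to sanity-check the propagation this override describes (a sketch, not part of the patch: the checkpoint name is an example and `_flash_attn_2_enabled` is a private flag of this transformers version, so treat both as assumptions):

import torch
from transformers import BarkModel

model = BarkModel.from_pretrained(
    "suno/bark-small", torch_dtype=torch.float16, use_flash_attention_2=True
)
for name in ("semantic_config", "coarse_acoustics_config", "fine_acoustics_config"):
    sub_config = getattr(model.config, name)
    # Expected to print True for each sub-config when flash-attn is installed and usable.
    print(name, getattr(sub_config, "_flash_attn_2_enabled", False))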
--- a/tests/models/bark/test_modeling_bark.py
+++ b/tests/models/bark/test_modeling_bark.py
@@ -20,6 +20,8 @@ import inspect
 import tempfile
 import unittest

+from pytest import mark
+
 from transformers import (
     BarkCoarseConfig,
     BarkConfig,
@@ -33,6 +35,7 @@ from transformers.models.bark.generation_configuration_bark import (
     BarkSemanticGenerationConfig,
 )
 from transformers.testing_utils import (
+    require_flash_attn,
     require_torch,
     require_torch_fp16,
     require_torch_gpu,
@@ -872,6 +875,122 @@ class BarkFineModelTest(ModelTesterMixin, unittest.TestCase):
         # Check that the model can still do a forward pass successfully (every parameter should be resized)
         model(**self._prepare_for_class(inputs_dict, model_class))

+    @require_flash_attn
+    @require_torch_gpu
+    @mark.flash_attn_test
+    @slow
+    def test_flash_attn_2_inference(self):
+        for model_class in self.all_model_classes:
+            if not model_class._supports_flash_attn_2:
+                return
+
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+            model = model_class(config)
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                model.save_pretrained(tmpdirname)
+                model_fa = model_class.from_pretrained(
+                    tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
+                )
+                model_fa.to(torch_device)
+
+                model = model_class.from_pretrained(
+                    tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
+                )
+                model.to(torch_device)
+
+                dummy_input = inputs_dict["input_ids"][:1]
+                if dummy_input.dtype in [torch.float32, torch.float16]:
+                    dummy_input = dummy_input.to(torch.bfloat16)
+
+                dummy_attention_mask = inputs_dict.get("attention_mask", None)
+
+                if dummy_attention_mask is not None:
+                    dummy_attention_mask = dummy_attention_mask[:1]
+                    dummy_attention_mask[:, 1:] = 1
+                    dummy_attention_mask[:, :1] = 0
+
+                outputs = model(inputs_dict["codebook_idx"], dummy_input, output_hidden_states=True)
+                outputs_fa = model_fa(inputs_dict["codebook_idx"], dummy_input, output_hidden_states=True)
+
+                logits = outputs.hidden_states[-1]
+                logits_fa = outputs_fa.hidden_states[-1]
+
+                assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
+
+                other_inputs = {"output_hidden_states": True}
+                if dummy_attention_mask is not None:
+                    other_inputs["attention_mask"] = dummy_attention_mask
+
+                outputs = model(inputs_dict["codebook_idx"], dummy_input, **other_inputs)
+                outputs_fa = model_fa(inputs_dict["codebook_idx"], dummy_input, **other_inputs)
+
+                logits = outputs.hidden_states[-1]
+                logits_fa = outputs_fa.hidden_states[-1]
+
+                assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2)
+
+                # check with inference + dropout
+                model.train()
+                _ = model_fa(inputs_dict["codebook_idx"], dummy_input, **other_inputs)
+
+    @require_flash_attn
+    @require_torch_gpu
+    @mark.flash_attn_test
+    @slow
+    def test_flash_attn_2_inference_padding_right(self):
+        for model_class in self.all_model_classes:
+            if not model_class._supports_flash_attn_2:
+                return
+
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+            model = model_class(config)
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                model.save_pretrained(tmpdirname)
+                model_fa = model_class.from_pretrained(
+                    tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
+                )
+                model_fa.to(torch_device)
+
+                model = model_class.from_pretrained(
+                    tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
+                )
+                model.to(torch_device)
+
+                dummy_input = inputs_dict["input_ids"][:1]
+                if dummy_input.dtype in [torch.float32, torch.float16]:
+                    dummy_input = dummy_input.to(torch.bfloat16)
+
+                dummy_attention_mask = inputs_dict.get("attention_mask", None)
+
+                if dummy_attention_mask is not None:
+                    dummy_attention_mask = dummy_attention_mask[:1]
+                    dummy_attention_mask[:, :-1] = 1
+                    dummy_attention_mask[:, -1:] = 0
+
+                outputs = model(inputs_dict["codebook_idx"], dummy_input, output_hidden_states=True)
+                outputs_fa = model_fa(inputs_dict["codebook_idx"], dummy_input, output_hidden_states=True)
+
+                logits = outputs.hidden_states[-1]
+                logits_fa = outputs_fa.hidden_states[-1]
+
+                assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
+
+                other_inputs = {
+                    "output_hidden_states": True,
+                }
+                if dummy_attention_mask is not None:
+                    other_inputs["attention_mask"] = dummy_attention_mask
+
+                outputs = model(inputs_dict["codebook_idx"], dummy_input, **other_inputs)
+                outputs_fa = model_fa(inputs_dict["codebook_idx"], dummy_input, **other_inputs)
+
+                logits = outputs.hidden_states[-1]
+                logits_fa = outputs_fa.hidden_states[-1]
+
+                assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
+

 @require_torch
 class BarkModelIntegrationTests(unittest.TestCase):