diff --git a/docs/source/en/model_doc/distilbert.md b/docs/source/en/model_doc/distilbert.md index 7633adcae42..5742380c517 100644 --- a/docs/source/en/model_doc/distilbert.md +++ b/docs/source/en/model_doc/distilbert.md @@ -133,6 +133,37 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h - A blog post on how to [deploy DistilBERT with Amazon SageMaker](https://huggingface.co/blog/deploy-hugging-face-models-easily-with-amazon-sagemaker). - A blog post on how to [Deploy BERT with Hugging Face Transformers, Amazon SageMaker and Terraform module](https://www.philschmid.de/terraform-huggingface-amazon-sagemaker). + +## Combining DistilBERT and Flash Attention 2 + +First, make sure to install the latest version of Flash Attention 2 to include the sliding window attention feature. + +```bash +pip install -U flash-attn --no-build-isolation +``` + +Make also sure that you have a hardware that is compatible with Flash-Attention 2. Read more about it in the official documentation of flash-attn repository. Make also sure to load your model in half-precision (e.g. `torch.float16`) + +To load and run a model using Flash Attention 2, refer to the snippet below: + +```python +>>> import torch +>>> from transformers import AutoTokenizer, AutoModel + +>>> device = "cuda" # the device to load the model onto + +>>> tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased') +>>> model = AutoModel.from_pretrained("distilbert-base-uncased", torch_dtype=torch.float16, use_flash_attention_2=True) + +>>> text = "Replace me by any text you'd like." + +>>> encoded_input = tokenizer(text, return_tensors='pt').to(device) +>>> model.to(device) + +>>> output = model(**encoded_input) +``` + + ## DistilBertConfig [[autodoc]] DistilBertConfig diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py index c66519a7245..144fde42e0b 100755 --- a/src/transformers/models/distilbert/modeling_distilbert.py +++ b/src/transformers/models/distilbert/modeling_distilbert.py @@ -24,6 +24,7 @@ from typing import Dict, List, Optional, Set, Tuple, Union import numpy as np import torch +import torch.nn.functional as F from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss @@ -44,12 +45,18 @@ from ...utils import ( add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + is_flash_attn_2_available, logging, replace_return_docstrings, ) from .configuration_distilbert import DistilBertConfig +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "distilbert-base-uncased" _CONFIG_FOR_DOC = "DistilBertConfig" @@ -69,6 +76,19 @@ DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [ # UTILS AND BUILDING BLOCKS OF THE ARCHITECTURE # +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + def create_sinusoidal_embeddings(n_pos: int, dim: int, out: torch.Tensor): if is_deepspeed_zero3_enabled(): import deepspeed @@ -141,10 +161,12 @@ class Embeddings(nn.Module): class MultiHeadSelfAttention(nn.Module): def __init__(self, config: PretrainedConfig): super().__init__() + self.config = config self.n_heads = config.n_heads self.dim = config.dim self.dropout = nn.Dropout(p=config.attention_dropout) + self.is_causal = False # Have an even number of multi heads that divide the dimensions if self.dim % self.n_heads != 0: @@ -240,6 +262,178 @@ class MultiHeadSelfAttention(nn.Module): return (context,) +class DistilBertFlashAttention2(MultiHeadSelfAttention): + """ + DistilBert flash attention module. This module inherits from `MultiHeadSelfAttention` as the weights of the module + stays untouched. The only required change would be on the forward pass where it needs to correctly call the public + API of flash attention and deal with padding tokens in case the input contains any of them. + """ + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + mask: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, ...]: + """ + Parameters: + query: torch.tensor(bs, seq_length, dim) + key: torch.tensor(bs, seq_length, dim) + value: torch.tensor(bs, seq_length, dim) + mask: torch.tensor(bs, seq_length) + + Returns: + weights: torch.tensor(bs, n_heads, seq_length, seq_length) Attention weights context: torch.tensor(bs, + seq_length, dim) Contextualized layer. Optional: only if `output_attentions=True` + """ + batch_size, q_length, dim = query.size() + + dim_per_head = self.dim // self.n_heads + + def reshape(x: torch.Tensor) -> torch.Tensor: + """separate heads""" + return x.view(batch_size, -1, self.n_heads, dim_per_head) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + query_states = reshape(self.q_lin(query)) + key_states = reshape(self.k_lin(key)) + value_states = reshape(self.v_lin(value)) + + attn_dropout = self.config.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + if query_states.dtype == torch.float32: + # Handle the case where the model is quantized + if hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_lin.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_weights = self._flash_attention_forward( + query_states, key_states, value_states, mask, q_length, dropout=attn_dropout + ) + + attn_weights_reshaped = attn_weights.reshape(batch_size, q_length, self.n_heads * dim_per_head) + attn_output = self.out_lin(attn_weights_reshaped) + + if output_attentions: + return (attn_output, attn_weights) + else: + return (attn_output,) + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward with causal=True->causal=False + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=self.is_causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal + ) + + return attn_output + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input with num_heads->n_heads + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.n_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + class FFN(nn.Module): def __init__(self, config: PretrainedConfig): super().__init__() @@ -269,7 +463,11 @@ class TransformerBlock(nn.Module): if config.dim % config.n_heads != 0: raise ValueError(f"config.n_heads {config.n_heads} must divide config.dim {config.dim} evenly") - self.attention = MultiHeadSelfAttention(config) + self.attention = ( + MultiHeadSelfAttention(config) + if not getattr(config, "_flash_attn_2_enabled", False) + else DistilBertFlashAttention2(config) + ) self.sa_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12) self.ffn = FFN(config) @@ -407,6 +605,7 @@ class DistilBertPreTrainedModel(PreTrainedModel): load_tf_weights = None base_model_prefix = "distilbert" supports_gradient_checkpointing = True + _supports_flash_attn_2 = True def _init_weights(self, module: nn.Module): """Initialize the weights.""" @@ -588,14 +787,17 @@ class DistilBertModel(DistilBertPreTrainedModel): device = input_ids.device if input_ids is not None else inputs_embeds.device - if attention_mask is None: - attention_mask = torch.ones(input_shape, device=device) # (bs, seq_length) - # Prepare head mask if needed head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) embeddings = self.embeddings(input_ids, inputs_embeds) # (bs, seq_length, dim) + if getattr(self.config, "_flash_attn_2_enabled", False): + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + else: + if attention_mask is None: + attention_mask = torch.ones(input_shape, device=device) # (bs, seq_length) + return self.transformer( x=embeddings, attn_mask=attention_mask, diff --git a/tests/models/distilbert/test_modeling_distilbert.py b/tests/models/distilbert/test_modeling_distilbert.py index 22e97653536..b6d3c0f57aa 100644 --- a/tests/models/distilbert/test_modeling_distilbert.py +++ b/tests/models/distilbert/test_modeling_distilbert.py @@ -16,8 +16,10 @@ import os import tempfile import unittest +from pytest import mark + from transformers import DistilBertConfig, is_torch_available -from transformers.testing_utils import require_torch, require_torch_accelerator, slow, torch_device +from transformers.testing_utils import require_flash_attn, require_torch, require_torch_accelerator, slow, torch_device from ...test_configuration_common import ConfigTester from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask @@ -285,6 +287,114 @@ class DistilBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa loaded = torch.jit.load(os.path.join(tmp, "traced_model.pt"), map_location=torch_device) loaded(inputs_dict["input_ids"].to(torch_device), inputs_dict["attention_mask"].to(torch_device)) + # Because DistilBertForMultipleChoice requires inputs with different shapes we need to override this test. + @require_flash_attn + @require_torch_accelerator + @mark.flash_attn_test + @slow + def test_flash_attn_2_inference(self): + import torch + + for model_class in self.all_model_classes: + dummy_input = torch.LongTensor( + [ + [1, 2, 3, 4], + [1, 2, 8, 9], + [1, 2, 11, 12], + [1, 2, 13, 14], + ] + ).to(torch_device) + dummy_attention_mask = torch.LongTensor( + [ + [0, 1, 1, 1], + [0, 1, 1, 1], + [0, 1, 1, 1], + [0, 1, 1, 1], + ] + ).to(torch_device) + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_fa = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True + ) + model_fa.to(torch_device) + + model = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False + ) + model.to(torch_device) + + logits = model(dummy_input, output_hidden_states=True).hidden_states[-1] + logits_fa = model_fa(dummy_input, output_hidden_states=True).hidden_states[-1] + + self.assertTrue(torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)) + + output_fa = model_fa(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True) + logits_fa = output_fa.hidden_states[-1] + + output = model(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True) + logits = output.hidden_states[-1] + + self.assertTrue(torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2)) + + # Because DistilBertForMultipleChoice requires inputs with different shapes we need to override this test. + @require_flash_attn + @require_torch_accelerator + @mark.flash_attn_test + @slow + def test_flash_attn_2_inference_padding_right(self): + import torch + + for model_class in self.all_model_classes: + dummy_input = torch.LongTensor( + [ + [1, 2, 3, 4], + [1, 2, 8, 9], + [1, 2, 11, 12], + [1, 2, 13, 14], + ] + ).to(torch_device) + dummy_attention_mask = torch.LongTensor( + [ + [0, 1, 1, 1], + [0, 1, 1, 1], + [0, 1, 1, 1], + [0, 1, 1, 1], + ] + ).to(torch_device) + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_fa = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True + ) + model_fa.to(torch_device) + + model = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False + ) + model.to(torch_device) + + logits = model(dummy_input, output_hidden_states=True).hidden_states[-1] + logits_fa = model_fa(dummy_input, output_hidden_states=True).hidden_states[-1] + + self.assertTrue(torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)) + + output_fa = model_fa(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True) + logits_fa = output_fa.hidden_states[-1] + + output = model(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True) + logits = output.hidden_states[-1] + + self.assertTrue(torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)) + @require_torch class DistilBertModelIntergrationTest(unittest.TestCase):