F.scaled_dot_product_attention support (#26572)

* add sdpa

* wip

* cleaning

* add ref

* yet more cleaning

* and more :)

* wip llama

* working llama

* add output_attentions=True support

* bigcode sdpa support

* fixes

* gpt-bigcode support, require torch>=2.1.1

* add falcon support

* fix conflicts falcon

* style

* fix attention_mask definition

* remove output_attentions from attnmaskconverter

* support whisper without removing any Copied from statement

* fix mbart default to eager renaming

* fix typo in falcon

* fix is_causal in SDPA

* check is_flash_attn_2_available in the models init as well in case the model is not initialized through from_pretrained

* add warnings when falling back on the manual implementation

* precise doc

* wip replace _flash_attn_enabled by config.attn_implementation

* fix typo

* add tests

* style

* add a copy.deepcopy on the config in from_pretrained, as we do not want to modify it inplace

* obey to config.attn_implementation if a config is passed in from_pretrained

* fix is_torch_sdpa_available when torch is not installed

* remove dead code

* Update src/transformers/modeling_attn_mask_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/modeling_attn_mask_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/modeling_attn_mask_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/modeling_attn_mask_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/modeling_attn_mask_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/bart/modeling_bart.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* remove duplicate pretraining_tp code

* add dropout in llama

* precise comment on attn_mask

* add fmt: off for _unmask_unattended docstring

* precise num_masks comment

* nuke pretraining_tp in LlamaSDPAAttention following Arthur's suggestion

* cleanup modeling_utils

* backward compatibility

* fix style as requested

* style

* improve documentation

* test pass

* style

* add _unmask_unattended tests

* skip meaningless tests for idefics

* hard_check SDPA requirements when specifically requested

* standardize the use of XXX_ATTENTION_CLASSES

* fix SDPA bug with mem-efficient backend on CUDA when using fp32

* fix test

* rely on SDPA is_causal parameter to handle the causal mask in some cases

* fix FALCON_ATTENTION_CLASSES

* remove _flash_attn_2_enabled occurrences

* fix test

* add OPT to the list of supported flash models

* improve test

* properly test on different SDPA backends, on different dtypes & properly handle separately the pad tokens in the test

* remove remaining _flash_attn_2_enabled occurrence

* Update src/transformers/modeling_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/modeling_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/modeling_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/modeling_attn_mask_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update docs/source/en/perf_infer_gpu_one.md

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* remove use_attn_implementation

* fix docstring & slight bug

* make attn_implementation internal (_attn_implementation)

* typos

* fix tests

* deprecate use_flash_attention_2=True

* fix test

* add back llama that was removed by mistake

* fix tests

* remove _flash_attn_2_enabled occurrences bis

* add check & test that passed attn_implementation is valid

* fix falcon torchscript export

* fix device of mask in tests

* add tip about torch.jit.trace and move bt doc below sdpa

* fix parameterized.expand order

* move tests from test_modeling_attn_mask_utils to test_modeling_utils as a relevant test class is already there

* update sdpaattention class with the new cache

* Update src/transformers/configuration_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/bark/modeling_bark.py

* address review comments

* WIP torch.jit.trace fix. left: test both eager & sdpa

* add test for torch.jit.trace for both eager/sdpa

* fix falcon with torch==2.0 that needs to use sdpa

* fix doc

* hopefully last fix

* fix key_value_length that has no default now in mask converter

* is it flaky?

* fix speculative decoding bug

* tests do pass

* fix following #27907

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
fxmarty 2023-12-08 21:38:14 +01:00 committed by GitHub
parent ce0bbd5101
commit 80377eb018
54 changed files with 2227 additions and 454 deletions

View File

@ -441,7 +441,7 @@ flush()
```
For comparison, let's run the same function, but enable Flash Attention instead.
To do so, we convert the model to [BetterTransformers](https://huggingface.co/docs/optimum/bettertransformer/overview) and by doing so enabling PyTorch's [SDPA self-attention](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention) which in turn is based on Flash Attention.
To do so, we convert the model to [BetterTransformer](https://huggingface.co/docs/optimum/bettertransformer/overview), which enables PyTorch's [SDPA self-attention](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention), which in turn is able to use Flash Attention.
```python
model.to_bettertransformer()

View File

@ -83,10 +83,10 @@ pip install -U flash-attn --no-build-isolation
##### Usage
To load a model using Flash Attention 2, we can pass the `use_flash_attention_2` flag to [`.from_pretrained`](https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained). We'll also load the model in half-precision (e.g. `torch.float16`), since it results in almost no degradation to audio quality but significantly lower memory usage and faster inference:
To load a model using Flash Attention 2, we can pass the `attn_implementation="flash_attention_2"` flag to [`.from_pretrained`](https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained). We'll also load the model in half-precision (e.g. `torch.float16`), since it results in almost no degradation to audio quality but significantly lower memory usage and faster inference:
```python
model = BarkModel.from_pretrained("suno/bark-small", torch_dtype=torch.float16, use_flash_attention_2=True).to(device)
model = BarkModel.from_pretrained("suno/bark-small", torch_dtype=torch.float16, attn_implementation="flash_attention_2").to(device)
```
##### Performance comparison
@ -114,7 +114,7 @@ import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
# load in fp16 and use Flash Attention 2
model = BarkModel.from_pretrained("suno/bark-small", torch_dtype=torch.float16, use_flash_attention_2=True).to(device)
model = BarkModel.from_pretrained("suno/bark-small", torch_dtype=torch.float16, attn_implementation="flash_attention_2").to(device)
# enable CPU offload
model.enable_cpu_offload()

View File

@ -153,7 +153,7 @@ To load and run a model using Flash Attention 2, refer to the snippet below:
>>> device = "cuda" # the device to load the model onto
>>> tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
>>> model = AutoModel.from_pretrained("distilbert-base-uncased", torch_dtype=torch.float16, use_flash_attention_2=True)
>>> model = AutoModel.from_pretrained("distilbert-base-uncased", torch_dtype=torch.float16, attn_implementation="flash_attention_2")
>>> text = "Replace me by any text you'd like."

View File

@ -59,7 +59,7 @@ To load and run a model using Flash Attention 2, refer to the snippet below:
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> device = "cuda" # the device to load the model onto
>>> model = AutoModelForCausalLM.from_pretrained("bigcode/gpt_bigcode-santacoder", torch_dtype=torch.float16, use_flash_attention_2=True)
>>> model = AutoModelForCausalLM.from_pretrained("bigcode/gpt_bigcode-santacoder", torch_dtype=torch.float16, attn_implementation="flash_attention_2")
>>> tokenizer = AutoTokenizer.from_pretrained("bigcode/gpt_bigcode-santacoder")
>>> prompt = "def hello_world():"

View File

@ -67,7 +67,7 @@ To load and run a model using Flash Attention 2, refer to the snippet below:
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> device = "cuda" # the device to load the model onto
>>> model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-2.7B", torch_dtype=torch.float16, use_flash_attention_2=True)
>>> model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-2.7B", torch_dtype=torch.float16, attn_implementation="flash_attention_2")
>>> tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-2.7B")
>>> prompt = "def hello_world():"

View File

@ -77,12 +77,12 @@ pip install -U flash-attn --no-build-isolation
### Usage
To load a model using Flash Attention 2, we can pass the `use_flash_attention_2` flag to [`.from_pretrained`](https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained). We'll also load the model in half-precision (e.g. `torch.float16`), since it results in almost no degradation to audio quality but significantly lower memory usage and faster inference:
To load a model using Flash Attention 2, we can pass the argument `attn_implementation="flash_attention_2"` to [`.from_pretrained`](https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained). We'll also load the model in half-precision (e.g. `torch.float16`), since it results in almost no degradation to audio quality but significantly lower memory usage and faster inference:
```python
>>> from transformers import GPTNeoXForCausalLM, GPTNeoXTokenizerFast
model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/gpt-neox-20b", torch_dtype=torch.float16, use_flash_attention_2=True).to(device)
model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/gpt-neox-20b", torch_dtype=torch.float16, attn_implementation="flash_attention_2").to(device)
...
```

View File

@ -99,7 +99,7 @@ To load and run a model using Flash Attention 2, refer to the snippet below:
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> device = "cuda" # the device to load the model onto
>>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", torch_dtype=torch.float16, use_flash_attention_2=True)
>>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", torch_dtype=torch.float16, attn_implementation="flash_attention_2")
>>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
>>> prompt = "My favourite condiment is"

View File

@ -80,7 +80,7 @@ To load and run a model using Flash Attention 2, refer to the snippet below:
>>> from transformers import OPTForCausalLM, GPT2Tokenizer
>>> device = "cuda" # the device to load the model onto
>>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.float16, use_flash_attention_2=True)
>>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.float16, attn_implementation="flash_attention_2")
>>> tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m")
>>> prompt = ("A chat between a curious human and the Statue of Liberty.\n\nHuman: What is your name?\nStatue: I am the "

View File

@ -111,7 +111,7 @@ To load and run a model using Flash Attention 2, refer to the snippet below:
>>> from transformers import PhiForCausalLM, AutoTokenizer
>>> # define the model and tokenizer and push the model and tokens to the GPU.
>>> model = PhiForCausalLM.from_pretrained("susnato/phi-1_5_dev", torch_dtype=torch.float16, use_flash_attention_2=True).to("cuda")
>>> model = PhiForCausalLM.from_pretrained("susnato/phi-1_5_dev", torch_dtype=torch.float16, attn_implementation="flash_attention_2").to("cuda")
>>> tokenizer = AutoTokenizer.from_pretrained("susnato/phi-1_5_dev")
>>> # feel free to change the prompt to your liking.
@ -163,4 +163,4 @@ Below is an expected speedup diagram that compares pure inference time between t
- forward
</pt>
</frameworkcontent>
</frameworkcontent>

View File

@ -36,13 +36,29 @@ FlashAttention-2 is experimental and may change considerably in future versions.
1. additionally parallelizing the attention computation over sequence length
2. partitioning the work between GPU threads to reduce communication and shared memory reads/writes between them
FlashAttention-2 supports inference with Llama, Mistral, Falcon and Bark models. You can request to add FlashAttention-2 support for another model by opening a GitHub Issue or Pull Request.
FlashAttention-2 is currently supported for the following architectures:
* [Bark](https://huggingface.co/docs/transformers/model_doc/bark#transformers.BarkModel)
* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
* [DistilBert](https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel)
* [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel)
* [GPTNeo](https://huggingface.co/docs/transformers/model_doc/gpt_neo#transformers.GPTNeoModel)
* [GPTNeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox#transformers.GPTNeoXModel)
* [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
* [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel)
* [Llava](https://huggingface.co/docs/transformers/model_doc/llava)
* [MBart](https://huggingface.co/docs/transformers/model_doc/mbart#transformers.MBartModel)
* [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel)
* [OPT](https://huggingface.co/docs/transformers/model_doc/opt#transformers.OPTModel)
* [Phi](https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel)
* [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel)
You can request to add FlashAttention-2 support for another model by opening a GitHub Issue or Pull Request.
Before you begin, make sure you have FlashAttention-2 installed. For NVIDIA GPUs, the library is installable through pip: `pip install flash-attn --no-build-isolation`. We strongly suggest referring to the [detailed installation instructions](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#installation-and-features).
FlashAttention-2 is also supported on AMD GPUs, with the current support limited to **Instinct MI210 and Instinct MI250**. We strongly suggest using the following [Dockerfile](https://github.com/huggingface/optimum-amd/tree/main/docker/transformers-pytorch-amd-gpu-flash/Dockerfile) to use FlashAttention-2 on AMD GPUs.
To enable FlashAttention-2, add the `use_flash_attention_2` parameter to [`~AutoModelForCausalLM.from_pretrained`]:
To enable FlashAttention-2, pass the argument `attn_implementation="flash_attention_2"` to [`~AutoModelForCausalLM.from_pretrained`]:
```python
import torch
@ -54,13 +70,15 @@ tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.bfloat16,
use_flash_attention_2=True,
attn_implementation="flash_attention_2",
)
```
<Tip>
FlashAttention-2 can only be used when the model's dtype is `fp16` or `bf16`. Make sure to cast your model to the appropriate dtype and load them on a supported device before using FlashAttention-2.
Note that `use_flash_attention_2=True` can also be used to enable Flash Attention 2, but is deprecated in favor of `attn_implementation="flash_attention_2"`.
</Tip>
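As a rough illustration of the migration described in the tip above, the sketch below maps the deprecated flag onto the new argument. The helper name `resolve_attn_implementation` and its exact behavior are hypothetical and are not taken from the Transformers code base; it only shows one plausible way the two options could be reconciled.

```python
import warnings


def resolve_attn_implementation(attn_implementation=None, use_flash_attention_2=False):
    """Hypothetical helper: translate the deprecated flag into the new argument."""
    if use_flash_attention_2:
        warnings.warn(
            "use_flash_attention_2=True is deprecated; "
            'pass attn_implementation="flash_attention_2" instead.',
            FutureWarning,
        )
        # Assumption: asking for two different implementations at once is an error.
        if attn_implementation not in (None, "flash_attention_2"):
            raise ValueError("Conflicting attention implementations were requested.")
        return "flash_attention_2"
    # None means "let the library pick its default".
    return attn_implementation
```

In user code, the practical change is simply to replace `use_flash_attention_2=True` with `attn_implementation="flash_attention_2"` in the call to `from_pretrained`.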
@ -77,14 +95,14 @@ tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
model_id,
load_in_8bit=True,
use_flash_attention_2=True,
attn_implementation="flash_attention_2",
)
# load in 4bit
model = AutoModelForCausalLM.from_pretrained(
model_id,
load_in_4bit=True,
use_flash_attention_2=True,
attn_implementation="flash_attention_2",
)
```
@ -124,8 +142,58 @@ FlashAttention is more memory efficient, meaning you can train on much larger se
<img src="https://huggingface.co/datasets/ybelkada/documentation-images/resolve/main/llama-2-large-seqlen-padding.png">
</div>
## FlashAttention and memory-efficient attention through PyTorch's scaled_dot_product_attention
PyTorch's [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) (SDPA) can also call FlashAttention and memory-efficient attention kernels under the hood. SDPA support is currently being added natively in Transformers, and is used by default for `torch>=2.1.1` when an implementation is available.
For now, Transformers supports inference and training through SDPA for the following architectures:
* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
* [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel)
* [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
* [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel)
* [Idefics](https://huggingface.co/docs/transformers/model_doc/idefics#transformers.IdeficsModel)
* [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel)
Note that FlashAttention can only be used for models with the `fp16` or `bf16` torch type, so make sure to cast your model to the appropriate type before using it.
By default, `torch.nn.functional.scaled_dot_product_attention` selects the most performant kernel available, but to check whether a backend is available in a given setting (hardware, problem size), you can use [`torch.backends.cuda.sdp_kernel`](https://pytorch.org/docs/master/backends.html#torch.backends.cuda.sdp_kernel) as a context manager:
```diff
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.float16).to("cuda")
# convert the model to BetterTransformer
model.to_bettertransformer()
input_text = "Hello my dog is cute and"
inputs = tokenizer(input_text, return_tensors="pt").to("cuda")
+ with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    outputs = model.generate(**inputs)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
If you see a bug with the traceback below, try using the nightly version of PyTorch, which may have broader coverage for FlashAttention:
```bash
RuntimeError: No available kernel. Aborting execution.
# install PyTorch nightly
pip3 install -U --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu118
```
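For readers who want to see what the fused kernels compute, here is a dependency-free reference of scaled dot-product attention for a single head. It is only a sketch for intuition: the function name `sdpa_reference` is made up for this example, it assumes the query and key sequences have the same length when `is_causal=True`, and it says nothing about how PyTorch implements or dispatches its kernels.

```python
import math


def sdpa_reference(query, key, value, is_causal=False):
    """Plain-Python attention for one head: softmax(Q K^T / sqrt(d)) V."""
    scale = 1.0 / math.sqrt(len(query[0]))
    output = []
    for i, q in enumerate(query):
        # With is_causal=True, position i only sees keys 0..i
        # (assumes len(query) == len(key)).
        visible = range(i + 1) if is_causal else range(len(key))
        scores = [scale * sum(a * b for a, b in zip(q, key[j])) for j in visible]
        peak = max(scores)  # subtract the max for numerical stability
        weights = [math.exp(s - peak) for s in scores]
        total = sum(weights)
        row = [
            sum(w / total * value[j][c] for w, j in zip(weights, visible))
            for c in range(len(value[0]))
        ]
        output.append(row)
    return output
```

With `is_causal=True`, the first query position can only attend to the first key, so its output row equals the first value row.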
## BetterTransformer
<Tip warning={true}>
Some BetterTransformer features are being upstreamed to Transformers, with default support for native `torch.nn.functional.scaled_dot_product_attention`. BetterTransformer still has wider coverage than the Transformers SDPA integration, but you can expect more and more architectures to natively support SDPA in Transformers.
</Tip>
<Tip>
Check out our benchmarks with BetterTransformer and scaled dot product attention in the [Out of the box acceleration and memory savings of 🤗 decoder models with PyTorch 2.0](https://pytorch.org/blog/out-of-the-box-acceleration/) and learn more about the fastpath execution in the [BetterTransformer](https://medium.com/pytorch/bettertransformer-out-of-the-box-performance-for-huggingface-transformers-3fbe27d50ab2) blog post.
@ -154,39 +222,6 @@ model = model.reverse_bettertransformer()
model.save_pretrained("saved_model")
```
### FlashAttention
SDPA can also call FlashAttention kernels under the hood. FlashAttention can only be used for models using the `fp16` or `bf16` dtype, so make sure to cast your model to the appropriate dtype before using it.
To enable FlashAttention or to check whether it is available in a given setting (hardware, problem size), use [`torch.backends.cuda.sdp_kernel`](https://pytorch.org/docs/master/backends.html#torch.backends.cuda.sdp_kernel) as a context manager:
```diff
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.float16).to("cuda")
# convert the model to BetterTransformer
model.to_bettertransformer()
input_text = "Hello my dog is cute and"
inputs = tokenizer(input_text, return_tensors="pt").to("cuda")
+ with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    outputs = model.generate(**inputs)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
If you see a bug with the traceback below, try using nightly version of PyTorch which may have broader coverage for FlashAttention:
```bash
RuntimeError: No available kernel. Aborting execution.
# install PyTorch nightly
pip3 install -U --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu118
```
## bitsandbytes
bitsandbytes is a quantization library that includes support for 4-bit and 8-bit quantization. Quantization reduces your model size compared to its native full precision version, making it easier to fit large models onto GPUs with limited memory.

View File

@ -82,7 +82,7 @@ AWQ quantization can also be combined with [FlashAttention-2](perf_infer_gpu_one
```py
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("TheBloke/zephyr-7B-alpha-AWQ", use_flash_attention_2=True, device_map="cuda:0")
model = AutoModelForCausalLM.from_pretrained("TheBloke/zephyr-7B-alpha-AWQ", attn_implementation="flash_attention_2", device_map="cuda:0")
```

View File

@ -44,7 +44,7 @@ Flash Attention 2は、モデルのdtypeが`fp16`または`bf16`の場合にの
### Quick usage
モデルでFlash Attention 2を有効にするには、`from_pretrained`の引数に`use_flash_attention_2`を追加します。
モデルでFlash Attention 2を有効にするには、`from_pretrained`の引数に`attn_implementation="flash_attention_2"`を追加します。
```python
@ -57,7 +57,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.bfloat16,
use_flash_attention_2=True,
attn_implementation="flash_attention_2",
)
```
@ -114,7 +114,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
model_id,
load_in_8bit=True,
use_flash_attention_2=True,
attn_implementation="flash_attention_2",
)
```
@ -132,7 +132,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
model_id,
load_in_4bit=True,
use_flash_attention_2=True,
attn_implementation="flash_attention_2",
)
```
@ -151,7 +151,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
model_id,
load_in_4bit=True,
use_flash_attention_2=True,
attn_implementation="flash_attention_2",
)
lora_config = LoraConfig(

View File

@ -66,12 +66,12 @@ model = AutoModelForCausalLM.from_pretrained(model_id).to("cuda:0")
### 结合 AWQ 和 Flash Attention
您可以将AWQ量化与Flash Attention结合起来得到一个既被量化又更快速的模型。只需使用`from_pretrained`加载模型,并传递`use_flash_attention_2=True`参数。
您可以将AWQ量化与Flash Attention结合起来得到一个既被量化又更快速的模型。只需使用`from_pretrained`加载模型,并传递`attn_implementation="flash_attention_2"`参数。
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("TheBloke/zephyr-7B-alpha-AWQ", use_flash_attention_2=True, device_map="cuda:0")
model = AutoModelForCausalLM.from_pretrained("TheBloke/zephyr-7B-alpha-AWQ", attn_implementation="flash_attention_2", device_map="cuda:0")
```
### 基准测试

View File

@ -236,6 +236,8 @@ class PretrainedConfig(PushToHubMixin):
This attribute is currently not being used during model loading time, but this may change in the future
versions. But we can already start preparing for the future by saving the dtype with save_pretrained.
attn_implementation (`str`, *optional*):
The attention implementation to use in the model. Can be any of `"eager"` (manual implementation of the attention), `"sdpa"` (attention using [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)), or `"flash_attention_2"` (attention using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation.
> TensorFlow specific parameters
@ -374,6 +376,9 @@ class PretrainedConfig(PushToHubMixin):
# Config hash
self._commit_hash = kwargs.pop("_commit_hash", None)
# Attention implementation to use, if relevant.
self._attn_implementation_internal = kwargs.pop("attn_implementation", None)
# Drop the transformers version info
self.transformers_version = kwargs.pop("transformers_version", None)
@ -422,6 +427,22 @@ class PretrainedConfig(PushToHubMixin):
self.id2label = {i: f"LABEL_{i}" for i in range(num_labels)}
self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))
@property
def _attn_implementation(self):
    # This property is made private for now (as it cannot be changed and a PreTrainedModel.use_attn_implementation method needs to be implemented.)
    if hasattr(self, "_attn_implementation_internal"):
        if self._attn_implementation_internal is None:
            # `config.attn_implementation` should never be None, for backward compatibility.
            return "eager"
        else:
            return self._attn_implementation_internal
    else:
        return "eager"

@_attn_implementation.setter
def _attn_implementation(self, value):
    self._attn_implementation_internal = value
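The fallback behavior of this property can be tried in isolation with a small stand-in class. `TinyConfig` below is hypothetical and is not the real `PretrainedConfig`; it only mirrors the pieces visible in this diff: the value popped from `kwargs`, the `"eager"` fallback, and the removal of the internal field on serialization.

```python
class TinyConfig:
    """Hypothetical stand-in for a config class; not the real PretrainedConfig."""

    def __init__(self, **kwargs):
        self._attn_implementation_internal = kwargs.pop("attn_implementation", None)

    @property
    def _attn_implementation(self):
        # An unset (or missing) value resolves to "eager" for backward compatibility.
        if getattr(self, "_attn_implementation_internal", None) is None:
            return "eager"
        return self._attn_implementation_internal

    @_attn_implementation.setter
    def _attn_implementation(self, value):
        self._attn_implementation_internal = value

    def to_dict(self):
        # Mirror the diff: the internal field is dropped from the serialized output.
        output = dict(vars(self))
        output.pop("_attn_implementation_internal", None)
        return output
```

Keeping the raw value in a separate internal attribute lets the property always return a usable string, while the serialized config stays free of a runtime-only setting.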
def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
"""
Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
@ -747,6 +768,9 @@ class PretrainedConfig(PushToHubMixin):
if "_commit_hash" in kwargs and "_commit_hash" in config_dict:
kwargs["_commit_hash"] = config_dict["_commit_hash"]
# We remove it from kwargs so that it does not appear in `return_unused_kwargs`.
config_dict["attn_implementation"] = kwargs.pop("attn_implementation", None)
config = cls(**config_dict)
if hasattr(config, "pruned_heads"):
@ -861,8 +885,8 @@ class PretrainedConfig(PushToHubMixin):
self.dict_torch_dtype_to_str(serializable_config_dict)
if "_flash_attn_2_enabled" in serializable_config_dict:
del serializable_config_dict["_flash_attn_2_enabled"]
if "_attn_implementation_internal" in serializable_config_dict:
del serializable_config_dict["_attn_implementation_internal"]
return serializable_config_dict
@ -880,8 +904,8 @@ class PretrainedConfig(PushToHubMixin):
del output["_auto_class"]
if "_commit_hash" in output:
del output["_commit_hash"]
if "_flash_attn_2_enabled" in output:
del output["_flash_attn_2_enabled"]
if "_attn_implementation_internal" in output:
del output["_attn_implementation_internal"]
# Transformers version when serializing the model
output["transformers_version"] = __version__

View File

@ -68,7 +68,7 @@ class AttentionMaskConverter:
key_value_length: int,
dtype: torch.dtype,
device: Union[torch.device, "str"] = "cpu",
) -> torch.Tensor:
) -> Optional[torch.Tensor]:
"""
Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative
bias to upper right hand triangular matrix (causal mask).
@ -184,6 +184,95 @@ class AttentionMaskConverter:
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
@staticmethod
def _unmask_unattended(
    expanded_mask: torch.Tensor, attention_mask: torch.Tensor, unmasked_value: Union[bool, float]
):
    # fmt: off
    """
    Attend to all tokens in masked rows from the expanded attention mask, for example the relevant first rows when
    using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
    Details: https://github.com/pytorch/pytorch/issues/110213

    `expanded_mask` is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len].
    `attention_mask` is [bsz, src_seq_len].

    The dimension num_masks of `expanded_mask` is most often 1, but it can also be the number of heads in the case of alibi attention bias.

    For example, if `attention_mask` is
    ```
    [[0, 0, 1],
     [1, 1, 1],
     [0, 1, 1]]
    ```
    and `expanded_mask` is (e.g. here left-padding case)
    ```
    [[[[0, 0, 0],
       [0, 0, 0],
       [0, 0, 1]]],
     [[[1, 0, 0],
       [1, 1, 0],
       [1, 1, 1]]],
     [[[0, 0, 0],
       [0, 1, 0],
       [0, 1, 1]]]]
    ```
    then the modified `expanded_mask` will be
    ```
    [[[[1, 1, 1],   <-- modified
       [1, 1, 1],   <-- modified
       [0, 0, 1]]],
     [[[1, 0, 0],
       [1, 1, 0],
       [1, 1, 1]]],
     [[[1, 1, 1],   <-- modified
       [0, 1, 0],
       [0, 1, 1]]]]
    ```
    """
    # fmt: on

    # Get the index of the first non-zero value for every sample in the batch.
    # In the above example, indices = [[2], [0], [1]]
    tmp = torch.arange(attention_mask.shape[1], 0, -1)
    indices = torch.argmax(attention_mask.cpu() * tmp, 1, keepdim=True)

    # Find the batch indexes that have unattended tokens on the leftmost side (e.g. [0, 0, 1, 1, 1]), for which the first rows of the
    # expanded mask will be completely unattended.
    left_masked_rows = torch.where(indices > 0)[0]

    if left_masked_rows.shape[0] == 0:
        return expanded_mask
    indices = indices[left_masked_rows]

    max_len = torch.max(indices)
    range_tensor = torch.arange(max_len).unsqueeze(0)
    range_tensor = range_tensor.repeat(indices.size(0), 1)

    # Avoid unmasking tokens at relevant target positions (on the row axis), by rather unmasking possibly several times the first row that should always be unmasked as we filtered out the batch above.
    range_tensor[range_tensor >= indices] = 0

    # TODO: we may drop support for 3D attention mask as the refactor from Patrick maybe dropped this case
    if expanded_mask.dim() == 4:
        num_masks = expanded_mask.shape[1]
        if num_masks == 1:
            # Broadcast [left_masked_rows, 1], [left_masked_rows, max_len]
            mask_slice = (left_masked_rows[:, None], 0, range_tensor)
        else:
            # Broadcast [left_masked_rows, 1, 1], [1, num_masks, 1], [left_masked_rows, 1, max_len]
            mask_slice = (
                left_masked_rows[:, None, None],
                torch.arange(num_masks)[None, :, None],
                range_tensor[:, None, :],
            )
    else:
        # Broadcast [left_masked_rows, 1], [left_masked_rows, max_len]
        mask_slice = (left_masked_rows[:, None], range_tensor)

    expanded_mask[mask_slice] = unmasked_value

    return expanded_mask
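The effect of `_unmask_unattended` on the docstring example can be reproduced with nested lists. The function below is a simplified sketch, not the method from the diff: the name `unmask_unattended_rows` is made up, it only handles the 4D case with `num_masks == 1`, it uses loops instead of tensor indexing, and it adds a bounds guard that the original does not spell out.

```python
def unmask_unattended_rows(expanded_mask, attention_mask, unmasked_value):
    """List-based sketch for a [bsz, 1, tgt_len, src_len] mask (num_masks == 1)."""
    for batch_idx, token_mask in enumerate(attention_mask):
        # Index of the first attended token; 0 if nothing is attended,
        # which matches what an argmax over an all-zero row would return.
        first_attended = next((i for i, v in enumerate(token_mask) if v != 0), 0)
        rows = expanded_mask[batch_idx][0]
        # Guard added for this sketch: never index past the last target row.
        for tgt in range(min(first_attended, len(rows))):
            # Fully unmask the rows that belong to left-padding positions.
            rows[tgt] = [unmasked_value] * len(rows[tgt])
    return expanded_mask
```

Applied to the `attention_mask` and `expanded_mask` from the docstring with `unmasked_value=1`, this yields the modified mask shown there: the first two rows of sample 0 and the first row of sample 2 become all ones.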
def _prepare_4d_causal_attention_mask(
attention_mask: Optional[torch.Tensor],
@ -225,6 +314,78 @@ def _prepare_4d_causal_attention_mask(
return attention_mask
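The branching in `_prepare_4d_causal_attention_mask_for_sdpa`, added in this hunk, can be summarized as a small decision function. This is a sketch for reading the control flow only: `plan_sdpa_mask` and its string labels are hypothetical, and the `"provided_mask"` case stands for a branch whose body is not visible in this excerpt.

```python
def plan_sdpa_mask(mask_given, mask_all_ones, query_length, key_value_length, is_tracing=False):
    """Return a label for the `attn_mask` strategy chosen by the excerpt's control flow."""
    if mask_given:
        if mask_all_ones and not is_tracing:
            if query_length == 1 or key_value_length == query_length:
                # Nothing is masked: drop the mask and rely on SDPA's `is_causal`.
                return "none"
        # Padding present, tracing, or mismatched lengths: keep the caller's mask.
        return "provided_mask"
    if query_length > 1 and key_value_length != query_length:
        # No mask given, but SDPA's own causal mask may be wrong here,
        # so an explicit causal mask is built with `to_causal_4d`.
        return "causal_4d"
    if is_tracing:
        raise ValueError("SDPA cannot be traced with torch.jit.trace without an attention_mask.")
    return "none"
```

Returning `"none"` is the case that matters for speed, since it lets SDPA dispatch to the flash attention kernel, which cannot be used when a custom `attn_mask` is passed.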
# Adapted from _prepare_4d_causal_attention_mask
def _prepare_4d_causal_attention_mask_for_sdpa(
    attention_mask: Optional[torch.Tensor],
    input_shape: Union[torch.Size, Tuple, List],
    inputs_embeds: torch.Tensor,
    past_key_values_length: int,
    sliding_window: Optional[int] = None,
):
    """
    Prepares the correct `attn_mask` argument to be used by `torch.nn.functional.scaled_dot_product_attention`.

    In case no token is masked in the `attention_mask` argument, we simply set it to `None` for the cases `query_length == 1` and
    `key_value_length == query_length`, and rely instead on SDPA `is_causal` argument to use causal/non-causal masks,
    allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed).
    """
    attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)

    key_value_length = input_shape[-1] + past_key_values_length
    batch_size, query_length = input_shape

    # torch.jit.trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
    # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
    # TODO: Fix this as well when using torchdynamo with fullgraph=True.
    is_tracing = torch.jit.is_tracing()

    if attention_mask is not None:
        if torch.all(attention_mask == 1):
            if is_tracing:
                pass
            elif query_length == 1:
                # For query_length == 1, causal attention and bi-directional attention are the same.
                attention_mask = None
            elif key_value_length == query_length:
                attention_mask = None
            else:
                # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation
                # may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
                # Reference: https://github.com/pytorch/pytorch/issues/108108
                pass
    elif query_length > 1 and key_value_length != query_length:
        # See the comment above (https://github.com/pytorch/pytorch/issues/108108).
        # Ugly: we set it to True here to dispatch in the following controlflow to `to_causal_4d`.
        attention_mask = True
    elif is_tracing:
        raise ValueError(
            'Attention using SDPA can not be traced with torch.jit.trace when no attention_mask is provided. To solve this issue, please either load your model with the argument `attn_implementation="eager"` or pass an attention_mask input when tracing the model.'
        )

    if attention_mask is None:
        expanded_4d_mask = None
    elif attention_mask is True:
        expanded_4d_mask = attn_mask_converter.to_causal_4d(
            input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
)
else:
expanded_4d_mask = attn_mask_converter.to_4d(
attention_mask,
input_shape[-1],
dtype=inputs_embeds.dtype,
key_value_length=key_value_length,
)
# From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
# produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
if query_length > 1:
expanded_4d_mask = AttentionMaskConverter._unmask_unattended(
expanded_4d_mask, attention_mask, unmasked_value=0.0
)
return expanded_4d_mask
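The branching in `_prepare_4d_causal_attention_mask_for_sdpa` can be summarised with a small pure-Python sketch. The function name `sdpa_mask_decision` and its string return values are hypothetical, introduced here only for illustration; the 2D mask is modelled as a nested list instead of a tensor.

```python
from typing import List, Optional


def sdpa_mask_decision(
    mask: Optional[List[List[int]]],
    query_length: int,
    key_value_length: int,
    is_tracing: bool = False,
) -> str:
    """Return which branch the helper above would take (illustrative names).

    "none"      -> pass attn_mask=None and rely on SDPA's is_causal argument
    "causal_4d" -> build a causal mask with to_causal_4d
    "expand_4d" -> expand the user-provided 2D mask with to_4d
    """
    if mask is not None:
        all_attended = all(value == 1 for row in mask for value in row)
        same_length_or_single_query = query_length == 1 or key_value_length == query_length
        if all_attended and not is_tracing and same_length_or_single_query:
            return "none"
        # Masked tokens, tracing, or q_len > 1 with kv_len != q_len: keep the mask.
        return "expand_4d"
    if query_length > 1 and key_value_length != query_length:
        return "causal_4d"
    if is_tracing:
        raise ValueError("SDPA cannot be traced without an attention_mask")
    return "none"


print(sdpa_mask_decision([[1, 1, 1]], query_length=3, key_value_length=3))
print(sdpa_mask_decision(None, query_length=2, key_value_length=5))
```

The sketch leaves out the `_unmask_unattended` post-processing step and the dtype and device handling of the real helper.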
def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
@ -241,13 +402,51 @@ def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len:
return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`
Args:
mask (`torch.Tensor` or `None`):
A 2D attention mask of shape `(batch_size, key_value_length)`
dtype (`torch.dtype`):
The torch dtype the created mask shall have.
tgt_len (`int`):
The target length or query length the created mask shall have.
"""
batch_size, key_value_length = mask.shape
tgt_len = tgt_len if tgt_len is not None else key_value_length
# torch.jit.trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
# used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
# TODO: Fix this as well when using torchdynamo with fullgraph=True.
is_tracing = torch.jit.is_tracing()
if torch.all(mask == 1):
if is_tracing:
pass
elif tgt_len == 1:
# For query_length == 1, causal attention and bi-directional attention are the same.
return None
elif key_value_length == tgt_len:
return None
else:
# Unfortunately, for query_length > 1 and key_value_length != query_length, we can not generally ignore the attention mask, as SDPA causal mask generation
# may be wrong. We will set is_causal=False in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
# Reference: https://github.com/pytorch/pytorch/issues/108108
return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
else:
return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
def _create_4d_causal_attention_mask(
input_shape: Union[torch.Size, Tuple, List],
dtype: torch.dtype,
device: torch.device,
past_key_values_length: int = 0,
sliding_window: Optional[int] = None,
):
) -> Optional[torch.Tensor]:
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)`
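The comments above refer to pytorch/pytorch#108108 as the reason the mask cannot be dropped when `query_length > 1` and `key_value_length != query_length`. The following pure-Python sketch illustrates the idea, assuming (as that issue discusses) that SDPA aligns its `is_causal` mask to the top-left corner. The helper `tril_mask` is hypothetical and only mimics a lower-triangular boolean mask.

```python
from typing import List


def tril_mask(q_len: int, kv_len: int, diagonal: int = 0) -> List[List[bool]]:
    """Boolean mask: True means query row i may attend to key column j."""
    return [[j <= i + diagonal for j in range(kv_len)] for i in range(q_len)]


# Two new query tokens on top of three cached key/value positions.
q_len, past_len = 2, 3
kv_len = q_len + past_len

# Assumed is_causal behaviour: lower-triangular mask aligned to the top-left corner.
top_left = tril_mask(q_len, kv_len)
# What a decoder with a KV cache needs: aligned to the bottom-right corner.
bottom_right = tril_mask(q_len, kv_len, diagonal=kv_len - q_len)

for name, mask in (("top-left", top_left), ("bottom-right", bottom_right)):
    print(name)
    for row in mask:
        print("  " + "".join("x" if allowed else "." for allowed in row))
```

When `q_len == kv_len` the two alignments coincide, and when `q_len == 1` the bottom-right mask allows every key, which is consistent with the mask being dropped only in those two cases.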


@ -81,6 +81,7 @@ from .utils import (
is_peft_available,
is_remote_url,
is_safetensors_available,
is_torch_sdpa_available,
is_torch_tpu_available,
logging,
replace_return_docstrings,
@ -1128,6 +1129,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# Flash Attention 2 support
_supports_flash_attn_2 = False
# SDPA support
_supports_sdpa = False
# Has support for a `Cache` instance as `past_key_values`
_supports_cache_class = False
@ -1154,7 +1158,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
)
# Save config and origin of the pretrained weights if given in model
config = self._autoset_attn_implementation(
config, torch_dtype=torch.get_default_dtype(), check_device_map=False
)
self.config = config
self.name_or_path = config.name_or_path
self.warnings_issued = {}
self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
@ -1185,8 +1193,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
Args:
torch_dtype (`torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model under this dtype.
use_flash_attention_2 (`bool`, *optional*):
Whether to load the model with Flash Attention 2 modules.
"""
torch_dtype = kwargs.pop("torch_dtype", None)
use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False)
@ -1196,8 +1202,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if torch_dtype is not None:
dtype_orig = cls._set_default_torch_dtype(torch_dtype)
if use_flash_attention_2:
config = cls._check_and_enable_flash_attn_2(config, torch_dtype)
config = copy.deepcopy(config) # We do not want to modify the config inplace in _from_config.
config._attn_implementation = kwargs.pop("attn_implementation", None)
config = cls._autoset_attn_implementation(
config, use_flash_attention_2=use_flash_attention_2, check_device_map=False
)
if is_deepspeed_zero3_enabled():
import deepspeed
@ -1216,6 +1225,67 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
return model
@classmethod
def _autoset_attn_implementation(
cls,
config,
use_flash_attention_2: bool = False,
torch_dtype: Optional[torch.dtype] = None,
device_map: Optional[Union[str, Dict[str, int]]] = None,
check_device_map: bool = True,
):
"""
Automatically checks and dispatches to a default attention implementation. In order of priority:
1. An implementation specified in `config._attn_implementation` (due for example to the argument attn_implementation="sdpa" in from_pretrained).
2. DEPRECATED: if use_flash_attention_2 is set to `True` and `flash_attn` is available, flash attention. (`LlamaFlashAttention` for example)
3. SDPA implementation, if available and supported by the model type. (`LlamaSdpaAttention` for example)
4. The default model's implementation otherwise (`LlamaAttention` for example).
"""
# Here we use config._attn_implementation_internal to check whether the attention implementation was explicitly set by the user.
# The property `PretrainedConfig._attn_implementation` is never `None`, for backward compatibility (always fall back on "eager").
# The `hasattr` here is used as some Transformers tests for some reason do not call PretrainedConfig __init__ (e.g. test_no_super_init_config_and_model)
if hasattr(config, "_attn_implementation_internal") and config._attn_implementation_internal is not None:
if config._attn_implementation != "flash_attention_2" and use_flash_attention_2:
raise ValueError(
f'Both attn_implementation="{config._attn_implementation}" and `use_flash_attention_2=True` were used when loading the model, which are not compatible.'
' We recommend to just use `attn_implementation="flash_attention_2"` when loading the model.'
)
if config._attn_implementation not in ["eager", "sdpa", "flash_attention_2"]:
message = f'Specified `attn_implementation="{config._attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)'
if cls._supports_flash_attn_2:
message += ', `attn_implementation="flash_attention_2"` (implementation using flash attention 2)'
if cls._supports_sdpa:
message += ', `attn_implementation="sdpa"` (implementation using torch.nn.functional.scaled_dot_product_attention)'
raise ValueError(message + ".")
# If a config is passed with a preset attn_implementation, we skip the automatic dispatch and use the user-provided config, with hard checks that the requested attention implementation is available.
hard_check_only = True
else:
hard_check_only = False
if use_flash_attention_2:
logger.warning_once(
'The model was loaded with use_flash_attention_2=True, which is deprecated and may be removed in a future release. Please use `attn_implementation="flash_attention_2"` instead.'
)
config._attn_implementation = "flash_attention_2"
if config._attn_implementation == "flash_attention_2":
cls._check_and_enable_flash_attn_2(
config,
torch_dtype=torch_dtype,
device_map=device_map,
hard_check_only=hard_check_only,
check_device_map=check_device_map,
)
elif cls._supports_sdpa or config._attn_implementation == "sdpa":
# use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif.
config = cls._check_and_enable_sdpa(config, hard_check_only=hard_check_only)
elif not hard_check_only:
config._attn_implementation = "eager"
return config
@classmethod
def _set_default_torch_dtype(cls, dtype: torch.dtype) -> torch.dtype:
"""
@ -1266,38 +1336,33 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
@classmethod
def _check_and_enable_flash_attn_2(
cls, config, torch_dtype: Optional[torch.dtype] = None, device_map: Optional[Union[str, Dict[str, int]]] = None
cls,
config,
torch_dtype: Optional[torch.dtype] = None,
device_map: Optional[Union[str, Dict[str, int]]] = None,
check_device_map: bool = True,
hard_check_only: bool = False,
) -> PretrainedConfig:
"""
If you don't know about Flash Attention, check out the official repository of flash attention:
https://github.com/Dao-AILab/flash-attention
Checks the availability of Flash Attention 2 and compatibility with the current model.
For using Flash Attention 1.0 you can do it directly via the `BetterTransformer` API, have a look at this
specific section of the documentation to learn more about it:
https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#decoder-models
The method checks if the current setup is compatible with Flash Attention as it requires the model to be in
half precision and not run on CPU.
If all checks pass, the method will create an attribute in the config `_flash_attn_2_enabled` so that the model
can initialize the correct attention module
If all checks pass and `hard_check_only` is False, the method will set the config attribute `_attn_implementation` to "flash_attention_2" so that the model can initialize the correct attention module.
"""
if not cls._supports_flash_attn_2:
raise ValueError(
"The current architecture does not support Flash Attention 2.0. Please open an issue on GitHub to "
f"{cls.__name__} does not support Flash Attention 2.0 yet. Please open an issue on GitHub to "
"request support for this architecture: https://github.com/huggingface/transformers/issues/new"
)
if not is_flash_attn_2_available():
flash_attention_version = version.parse(importlib.metadata.version("flash_attn"))
preface = "FlashAttention2 has been toggled on, but it cannot be used due to the following error:"
install_message = "Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2."
if torch.version.cuda:
if importlib.util.find_spec("flash_attn") is None:
raise ImportError(f"{preface} the package flash_attn seems to be not installed. {install_message}")
flash_attention_version = version.parse(importlib.metadata.version("flash_attn"))
if importlib.util.find_spec("flash_attn") is None:
raise ImportError(f"{preface} the package flash_attn seems to be not installed. {install_message}")
flash_attention_version = version.parse(importlib.metadata.version("flash_attn"))
if torch.version.cuda:
if flash_attention_version < version.parse("2.1.0"):
raise ImportError(
f"{preface} you need flash_attn package version to be greater or equal than 2.1.0. Detected version {flash_attention_version}. {install_message}"
@ -1305,9 +1370,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
else:
raise ImportError(f"{preface} Flash Attention 2 is not available. {install_message}")
elif torch.version.hip:
if importlib.util.find_spec("flash_attn") is None:
raise ImportError(f"{preface} the package flash_attn seems to be not installed. {install_message}")
flash_attention_version = version.parse(importlib.metadata.version("flash_attn"))
if flash_attention_version < version.parse("2.0.4"):
raise ImportError(
f"{preface} you need flash_attn package version to be greater or equal than 2.0.4. Make sure to have that version installed - detected version {flash_attention_version}. {install_message}"
@ -1332,20 +1394,23 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
" unexpected behaviour."
)
if device_map is None:
# The check `torch.empty(0).device.type != "cuda"` is needed as the model may be initialized after `torch.set_default_device` has been called,
# or the model may be initialized under the context manager `with torch.device("cuda"):`.
if check_device_map and device_map is None and torch.empty(0).device.type != "cuda":
if torch.cuda.is_available():
logger.warning(
"You are attempting to use Flash Attention 2.0 with a model initialized on CPU. Make sure to move the model to GPU"
"You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU"
" after initializing it on CPU with `model.to('cuda')`."
)
else:
raise ValueError(
"You are attempting to use Flash Attention 2.0 with a model initialized on CPU and with no GPU available. "
"You are attempting to use Flash Attention 2.0 with a model not initialized on GPU and with no GPU available. "
"This is not supported yet. Please make sure to have access to a GPU and either initialise the model on a GPU by passing a device_map "
"or initialising the model on CPU and then moving it to GPU."
)
elif (
device_map is not None
check_device_map
and device_map is not None
and isinstance(device_map, dict)
and ("cpu" in device_map.values() or "disk" in device_map.values())
):
@ -1353,7 +1418,37 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
"You are attempting to use Flash Attention 2.0 with a model dispatched on CPU or disk. This is not supported. Please make sure to "
"initialise the model on a GPU by passing a device_map that contains only GPU devices as keys."
)
config._flash_attn_2_enabled = True
if not hard_check_only:
config._attn_implementation = "flash_attention_2"
return config
@classmethod
def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> PretrainedConfig:
"""
Checks the availability of SDPA for a given model.
If all checks pass and `hard_check_only` is False, the method will set the config attribute `_attn_implementation` to "sdpa" so that the model can initialize the correct attention module.
"""
if hard_check_only:
if not cls._supports_sdpa:
raise ValueError(
f"{cls.__name__} does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention yet. Please open an issue on GitHub to "
"request support for this architecture: https://github.com/huggingface/transformers/issues/new"
)
if not is_torch_sdpa_available():
raise ImportError(
"PyTorch SDPA requirements in Transformers are not met. Please install torch>=2.1.1."
)
if not is_torch_sdpa_available() or not cls._supports_sdpa:
return config
_is_bettertransformer = getattr(cls, "use_bettertransformer", False)
if _is_bettertransformer:
return config
if not hard_check_only:
config._attn_implementation = "sdpa"
return config
def enable_input_require_grads(self):
@ -3312,8 +3407,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
elif load_in_8bit or load_in_4bit or low_cpu_mem_usage:
init_contexts.append(init_empty_weights())
if use_flash_attention_2:
config = cls._check_and_enable_flash_attn_2(config, torch_dtype=torch_dtype, device_map=device_map)
config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained.
config = cls._autoset_attn_implementation(
config, use_flash_attention_2=use_flash_attention_2, torch_dtype=torch_dtype, device_map=device_map
)
with ContextManagers(init_contexts):
model = cls(config, *model_args, **model_kwargs)
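The priority order documented in `_autoset_attn_implementation` can be sketched as a standalone function. This is a simplified illustration, not the library code: `resolve_attn_implementation` and its parameters are hypothetical, and the hardware, dtype, and `flash_attn` availability checks performed by the real hard-check paths are omitted.

```python
from typing import Optional

VALID_IMPLEMENTATIONS = ("eager", "sdpa", "flash_attention_2")


def resolve_attn_implementation(
    requested: Optional[str],
    use_flash_attention_2: bool = False,
    supports_sdpa: bool = False,
    sdpa_available: bool = True,
) -> str:
    """Pick an attention implementation following the documented priority."""
    if requested is not None:
        # 1. An explicit request wins, but must not conflict with the deprecated flag.
        if requested != "flash_attention_2" and use_flash_attention_2:
            raise ValueError("attn_implementation conflicts with use_flash_attention_2=True")
        if requested not in VALID_IMPLEMENTATIONS:
            raise ValueError(f"Unsupported attn_implementation: {requested!r}")
        return requested
    if use_flash_attention_2:
        # 2. Deprecated flag, kept for backward compatibility.
        return "flash_attention_2"
    if supports_sdpa and sdpa_available:
        # 3. SDPA when both the model class and the installed torch support it.
        return "sdpa"
    # 4. Fall back to the manual implementation.
    return "eager"


print(resolve_attn_implementation(None, supports_sdpa=True))
print(resolve_attn_implementation("eager", supports_sdpa=True))
```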


@ -389,7 +389,7 @@ class BarkSelfFlashAttention2(BarkSelfAttention):
BARK_ATTENTION_CLASSES = {
"default": BarkSelfAttention,
"eager": BarkSelfAttention,
"flash_attention_2": BarkSelfFlashAttention2,
}
@ -436,8 +436,7 @@ class BarkBlock(nn.Module):
self.layernorm_1 = nn.LayerNorm(config.hidden_size)
self.layernorm_2 = nn.LayerNorm(config.hidden_size)
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.attn = BARK_ATTENTION_CLASSES[attn_type](config, is_causal=is_causal)
self.attn = BARK_ATTENTION_CLASSES[config._attn_implementation](config, is_causal=is_causal)
self.mlp = BarkMLP(config)
@ -670,6 +669,7 @@ class BarkCausalModel(BarkPreTrainedModel):
self.drop = nn.Dropout(config.dropout)
self.layers = nn.ModuleList([BarkBlock(config, is_causal=True) for _ in range(config.num_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self.layernorm_final = BarkLayerNorm(config.hidden_size, bias=config.bias)
@ -805,7 +805,7 @@ class BarkCausalModel(BarkPreTrainedModel):
if attention_mask is not None:
if batch_size <= 0:
raise ValueError("batch_size has to be defined and > 0")
if getattr(self.config, "_flash_attn_2_enabled", False):
if self._use_flash_attention_2:
attention_mask = attention_mask if 0 in attention_mask else None
else:
attention_mask = attention_mask.view(batch_size, -1)
@ -1265,6 +1265,7 @@ class BarkFineModel(BarkPreTrainedModel):
self.drop = nn.Dropout(config.dropout)
self.layers = nn.ModuleList([BarkBlock(config, is_causal=False) for _ in range(config.num_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self.layernorm_final = nn.LayerNorm(config.hidden_size)
@ -1434,7 +1435,7 @@ class BarkFineModel(BarkPreTrainedModel):
if attention_mask is not None:
if batch_size <= 0:
raise ValueError("batch_size has to be defined and > 0")
if getattr(self.config, "_flash_attn_2_enabled", False):
if self._use_flash_attention_2:
attention_mask = attention_mask if 0 in attention_mask else None
else:
# [bsz, to_seq_length] -> [bsz, 1, 1, to_seq_length]
@ -1875,7 +1876,11 @@ class BarkModel(BarkPreTrainedModel):
@classmethod
def _check_and_enable_flash_attn_2(
cls, config, torch_dtype: Optional[torch.dtype] = None, device_map: Optional[Union[str, Dict[str, int]]] = None
cls,
config,
torch_dtype: Optional[torch.dtype] = None,
device_map: Optional[Union[str, Dict[str, int]]] = None,
hard_check_only: bool = False,
):
"""
`_check_and_enable_flash_attn_2` originally doesn't expand flash attention enabling to the model
@ -1892,12 +1897,14 @@ class BarkModel(BarkPreTrainedModel):
The method checks if the current setup is compatible with Flash Attention as it requires the model to be in
half precision and not run on CPU.
If all checks pass, the method will create an attribute in the config `_flash_attn_2_enabled` so that the model
If all checks pass and `hard_check_only` is False, the method will set the config attribute `_attn_implementation` to "flash_attention_2" so that the model
can initialize the correct attention module
"""
config = super()._check_and_enable_flash_attn_2(config, torch_dtype, device_map)
config = super()._check_and_enable_flash_attn_2(
config, torch_dtype, device_map, hard_check_only=hard_check_only
)
config.semantic_config._flash_attn_2_enabled = getattr(config, "_flash_attn_2_enabled", False)
config.coarse_acoustics_config._flash_attn_2_enabled = getattr(config, "_flash_attn_2_enabled", False)
config.fine_acoustics_config._flash_attn_2_enabled = getattr(config, "_flash_attn_2_enabled", False)
config.semantic_config._attn_implementation = config._attn_implementation
config.coarse_acoustics_config._attn_implementation = config._attn_implementation
config.fine_acoustics_config._attn_implementation = config._attn_implementation
return config


@ -25,7 +25,12 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
from ...modeling_attn_mask_utils import (
_prepare_4d_attention_mask,
_prepare_4d_attention_mask_for_sdpa,
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
@ -505,8 +510,109 @@ class BartFlashAttention2(BartAttention):
)
class BartSdpaAttention(BartAttention):
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
if output_attentions or layer_head_mask is not None:
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
logger.warning_once(
"BartModel is using BartSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states,
key_value_states=key_value_states,
past_key_value=past_key_value,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, _ = hidden_states.size()
# get query proj
query_states = self.q_proj(hidden_states)
# get key, value proj
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
# is checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
if (
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
query_states = self._shape(query_states, tgt_len, bsz)
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=self.dropout if self.training else 0.0,
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
is_causal=self.is_causal and attention_mask is None and tgt_len > 1,
)
if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2)
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned across GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, None, past_key_value
BART_ATTENTION_CLASSES = {
"default": BartAttention,
"eager": BartAttention,
"sdpa": BartSdpaAttention,
"flash_attention_2": BartFlashAttention2,
}
@ -515,9 +621,8 @@ class BartEncoderLayer(nn.Module):
def __init__(self, config: BartConfig):
super().__init__()
self.embed_dim = config.d_model
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.self_attn = BART_ATTENTION_CLASSES[attn_type](
self.self_attn = BART_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
@ -587,8 +692,7 @@ class BartDecoderLayer(nn.Module):
super().__init__()
self.embed_dim = config.d_model
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.self_attn = BART_ATTENTION_CLASSES[attn_type](
self.self_attn = BART_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
@ -601,7 +705,7 @@ class BartDecoderLayer(nn.Module):
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.encoder_attn = BART_ATTENTION_CLASSES[attn_type](
self.encoder_attn = BART_ATTENTION_CLASSES[config._attn_implementation](
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
@ -735,6 +839,7 @@ class BartPreTrainedModel(PreTrainedModel):
_no_split_modules = [r"BartEncoderLayer", r"BartDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_sdpa = True
def _init_weights(self, module):
std = self.config.init_std
@ -961,6 +1066,8 @@ class BartEncoder(BartPreTrainedModel):
embed_dim,
)
self.layers = nn.ModuleList([BartEncoderLayer(config) for _ in range(config.encoder_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self._use_sdpa = config._attn_implementation == "sdpa"
self.layernorm_embedding = nn.LayerNorm(embed_dim)
self.gradient_checkpointing = False
@ -1048,8 +1155,13 @@ class BartEncoder(BartPreTrainedModel):
# expand attention_mask
if attention_mask is not None:
if getattr(self.config, "_flash_attn_2_enabled", False):
if self._use_flash_attention_2:
attention_mask = attention_mask if 0 in attention_mask else None
elif self._use_sdpa and head_mask is None and not output_attentions:
# output_attentions=True & head_mask can not be supported when using SDPA, fall back to
# the manual implementation that requires a 4D causal mask in all cases.
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
else:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
@ -1136,6 +1248,9 @@ class BartDecoder(BartPreTrainedModel):
config.d_model,
)
self.layers = nn.ModuleList([BartDecoderLayer(config) for _ in range(config.decoder_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self._use_sdpa = config._attn_implementation == "sdpa"
self.layernorm_embedding = nn.LayerNorm(config.d_model)
self.gradient_checkpointing = False
@ -1254,9 +1369,18 @@ class BartDecoder(BartPreTrainedModel):
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input) * self.embed_scale
if getattr(self.config, "_flash_attn_2_enabled", False):
if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None:
# output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
input_shape,
inputs_embeds,
past_key_values_length,
)
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
@ -1265,8 +1389,17 @@ class BartDecoder(BartPreTrainedModel):
# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
if getattr(self.config, "_flash_attn_2_enabled", False):
if self._use_flash_attention_2:
encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
elif self._use_sdpa and cross_attn_head_mask is None and not output_attentions:
# output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
encoder_attention_mask,
inputs_embeds.dtype,
tgt_len=input_shape[-1],
)
else:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask(
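The two decisions made at the top and bottom of `BartSdpaAttention.forward` (whether to fall back to the manual implementation, and what to pass as `is_causal`) can be isolated in a small sketch. The function `plan_sdpa_call` and its boolean parameters are hypothetical stand-ins for the tensor-valued arguments of the real method.

```python
from typing import Optional, Tuple


def plan_sdpa_call(
    layer_is_causal: bool,
    has_attention_mask: bool,
    tgt_len: int,
    output_attentions: bool = False,
    has_layer_head_mask: bool = False,
) -> Tuple[str, Optional[bool]]:
    """Return (implementation, is_causal) for one attention call."""
    if output_attentions or has_layer_head_mask:
        # SDPA cannot return attention weights or apply a head mask: use the manual path.
        return ("eager", None)
    # is_causal is only set when no explicit mask is passed and there is more than one
    # query token, matching to_causal_4d, which builds no causal mask for tgt_len == 1.
    is_causal = layer_is_causal and not has_attention_mask and tgt_len > 1
    return ("sdpa", is_causal)


print(plan_sdpa_call(layer_is_causal=True, has_attention_mask=False, tgt_len=4))
print(plan_sdpa_call(layer_is_causal=True, has_attention_mask=False, tgt_len=1))
```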


@ -252,7 +252,7 @@ class BlenderbotAttention(nn.Module):
return attn_output, attn_weights_reshaped, past_key_value
BLENDERBOT_ATTENTION_CLASSES = {"default": BlenderbotAttention}
BLENDERBOT_ATTENTION_CLASSES = {"eager": BlenderbotAttention}
# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Blenderbot, MBART->BLENDERBOT
@ -260,9 +260,8 @@ class BlenderbotEncoderLayer(nn.Module):
def __init__(self, config: BlenderbotConfig):
super().__init__()
self.embed_dim = config.d_model
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.self_attn = BLENDERBOT_ATTENTION_CLASSES[attn_type](
self.self_attn = BLENDERBOT_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
@ -332,9 +331,8 @@ class BlenderbotDecoderLayer(nn.Module):
def __init__(self, config: BlenderbotConfig):
super().__init__()
self.embed_dim = config.d_model
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.self_attn = BLENDERBOT_ATTENTION_CLASSES[attn_type](
self.self_attn = BLENDERBOT_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
@ -347,7 +345,7 @@ class BlenderbotDecoderLayer(nn.Module):
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.encoder_attn = BLENDERBOT_ATTENTION_CLASSES[attn_type](
self.encoder_attn = BLENDERBOT_ATTENTION_CLASSES[config._attn_implementation](
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
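
The Blenderbot hunks above replace the `_flash_attn_2_enabled` flag check with a direct lookup keyed on `config._attn_implementation`. A minimal sketch of that dispatch pattern follows. `ToyConfig`, `EagerAttention`, `ToyLayer` and `TOY_ATTENTION_CLASSES` are hypothetical stand-ins for illustration, not the real Transformers classes:

```python
from dataclasses import dataclass


class EagerAttention:
    """Hypothetical stand-in for an eager attention module."""

    def __init__(self, embed_dim: int, num_heads: int):
        self.embed_dim = embed_dim
        self.num_heads = num_heads


# Only the implementations a model actually provides are registered.
TOY_ATTENTION_CLASSES = {"eager": EagerAttention}


@dataclass
class ToyConfig:
    d_model: int = 16
    attention_heads: int = 4
    _attn_implementation: str = "eager"


class ToyLayer:
    def __init__(self, config: ToyConfig):
        # A single dictionary lookup replaces the old getattr-based flag check.
        attn_cls = TOY_ATTENTION_CLASSES[config._attn_implementation]
        self.self_attn = attn_cls(config.d_model, config.attention_heads)


layer = ToyLayer(ToyConfig())
```

In this sketch, an implementation name that is not registered in the dictionary raises a `KeyError` when the layer is constructed.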


@ -254,9 +254,8 @@ class BlenderbotSmallEncoderLayer(nn.Module):
def __init__(self, config: BlenderbotSmallConfig):
super().__init__()
self.embed_dim = config.d_model
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.self_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[attn_type](
self.self_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
@ -321,7 +320,10 @@ class BlenderbotSmallEncoderLayer(nn.Module):
return outputs
BLENDERBOT_SMALL_ATTENTION_CLASSES = {"default": BlenderbotSmallAttention}
# TODO: Implement attention with SDPA for BlenderbotSmall.
BLENDERBOT_SMALL_ATTENTION_CLASSES = {
"eager": BlenderbotSmallAttention,
}
# Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->BlenderbotSmall, BART->BLENDERBOT_SMALL
@ -330,8 +332,7 @@ class BlenderbotSmallDecoderLayer(nn.Module):
super().__init__()
self.embed_dim = config.d_model
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.self_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[attn_type](
self.self_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
@ -344,7 +345,7 @@ class BlenderbotSmallDecoderLayer(nn.Module):
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.encoder_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[attn_type](
self.encoder_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[config._attn_implementation](
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,


@ -471,6 +471,12 @@ class FFN(nn.Module):
return x
DISTILBERT_ATTENTION_CLASSES = {
"eager": MultiHeadSelfAttention,
"flash_attention_2": DistilBertFlashAttention2,
}
class TransformerBlock(nn.Module):
def __init__(self, config: PretrainedConfig):
super().__init__()
@ -479,11 +485,7 @@ class TransformerBlock(nn.Module):
if config.dim % config.n_heads != 0:
raise ValueError(f"config.n_heads {config.n_heads} must divide config.dim {config.dim} evenly")
self.attention = (
MultiHeadSelfAttention(config)
if not getattr(config, "_flash_attn_2_enabled", False)
else DistilBertFlashAttention2(config)
)
self.attention = DISTILBERT_ATTENTION_CLASSES[config._attn_implementation](config)
self.sa_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12)
self.ffn = FFN(config)
@ -703,6 +705,7 @@ class DistilBertModel(DistilBertPreTrainedModel):
self.embeddings = Embeddings(config) # Embeddings
self.transformer = Transformer(config) # Encoder
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
# Initialize weights and apply final processing
self.post_init()
@ -808,7 +811,7 @@ class DistilBertModel(DistilBertPreTrainedModel):
embeddings = self.embeddings(input_ids, inputs_embeds) # (bs, seq_length, dim)
if getattr(self.config, "_flash_attn_2_enabled", False):
if self._use_flash_attention_2:
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
if attention_mask is None:


@ -16,7 +16,7 @@
import math
import warnings
from typing import Optional, Tuple, Union
from typing import TYPE_CHECKING, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
@ -24,7 +24,11 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
from torch.nn import functional as F
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from ...modeling_attn_mask_utils import (
AttentionMaskConverter,
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
@ -33,6 +37,7 @@ from ...modeling_outputs import (
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import is_torch_greater_or_equal_than_2_0
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
@ -44,6 +49,9 @@ from ...utils import (
from .configuration_falcon import FalconConfig
if TYPE_CHECKING:
from ...configuration_utils import PretrainedConfig
if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
@ -278,6 +286,7 @@ class FalconAttention(nn.Module):
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.is_causal = True
self._use_sdpa = config._attn_implementation == "sdpa"
if self.head_dim * self.num_heads != self.hidden_size:
raise ValueError(
@ -439,16 +448,15 @@ class FalconAttention(nn.Module):
present = None
if alibi is None:
if hasattr(F, "scaled_dot_product_attention") and not output_attentions:
# TODO: deprecate this once we add FA2 support in Falcon
logger.warning_once(
"The current implementation of Falcon calls `torch.scaled_dot_product_attention` directly, this will be deprecated in the"
" future in favor of the `BetterTransformer` API. Please install the latest optimum library with `pip install -U optimum` and call "
"`model.to_bettertransformer()` to benefit from `torch.scaled_dot_product_attention` and future performance optimizations."
)
if self._use_sdpa and not output_attentions:
attn_output = F.scaled_dot_product_attention(
query_layer, key_layer, value_layer, attention_mask, 0.0, is_causal=False
query_layer,
key_layer,
value_layer,
attention_mask,
0.0,
# The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case query_length == 1.
is_causal=self.is_causal and attention_mask is None and query_length > 1,
)
attention_scores = None
else:
@ -456,58 +464,70 @@ class FalconAttention(nn.Module):
attention_scores /= math.sqrt(self.head_dim)
attention_scores = F.softmax(attention_scores + attention_mask, dim=-1, dtype=hidden_states.dtype)
# It is unclear why neither dropout nor head_mask is applied here (while it is with alibi).
attn_output = attention_scores @ value_layer
attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim)
attn_output = attn_output.permute(0, 2, 1, 3)
attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)
output_tensor = self.dense(attn_output)
attn_output = self.dense(attn_output)
if output_attentions:
return output_tensor, present, attention_scores
return attn_output, present, attention_scores
else:
return output_tensor, present
return attn_output, present
else:
matmul_result = query_layer @ key_layer.transpose(-1, -2)
if self._use_sdpa and not output_attentions and head_mask is None:
attn_output = F.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
attn_mask=attention_mask,
dropout_p=self.attention_dropout.p if self.training else 0.0,
is_causal=self.is_causal and attention_mask is None and query_length > 1,
)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)
# change view to [batch_size, num_heads, q_length, kv_length]
attention_scores = matmul_result.view(batch_size, self.num_heads, query_length, kv_length)
attn_output = self.dense(attn_output)
else:
matmul_result = query_layer @ key_layer.transpose(-1, -2)
# cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
input_dtype = attention_scores.dtype
# `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
attention_scores = attention_scores.to(torch.float32)
# Matt (HF) note: We could possibly use F.scaled_dot_product_attention here too, by
# adding (alibi * self.inv_norm_factor) to attention_mask. I think this would be mathematically
# equivalent and more performant, but there might be a numerical difference. If you're reading this
# and you'd like to experiment and maybe file a PR, feel free!
attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)
attention_logits *= self.inv_norm_factor
attention_probs = F.softmax(attention_logits + attention_mask, dim=-1, dtype=hidden_states.dtype)
# [batch_size, num_heads, q_length, kv_length]
attention_probs = self.attention_dropout(attention_probs)
# change view to [batch_size, num_heads, q_length, kv_length]
attention_scores = matmul_result.view(batch_size, self.num_heads, query_length, kv_length)
if head_mask is not None:
attention_probs = attention_probs * head_mask
# cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
input_dtype = attention_scores.dtype
# `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
attention_scores = attention_scores.to(torch.float32)
# change view [batch_size, num_heads, q_length, kv_length]
attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length)
attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)
attention_logits *= self.inv_norm_factor
attention_probs = F.softmax(attention_logits + attention_mask, dim=-1, dtype=hidden_states.dtype)
# [batch_size, num_heads, q_length, kv_length]
attention_probs = self.attention_dropout(attention_probs)
# matmul: [batch_size * num_heads, q_length, head_dim]
context_layer = (attention_probs_reshaped @ value_layer).flatten(0, 1)
if head_mask is not None:
attention_probs = attention_probs * head_mask
# change view [batch_size, q_length, num_heads * head_dim]
context_layer = self._merge_heads(context_layer)
# change view [batch_size, num_heads, q_length, kv_length]
attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length)
output_tensor = self.dense(context_layer)
# matmul: [batch_size * num_heads, q_length, head_dim]
attn_output = (attention_probs_reshaped @ value_layer).flatten(0, 1)
# change view [batch_size, q_length, num_heads * head_dim]
attn_output = self._merge_heads(attn_output)
attn_output = self.dense(attn_output)
if output_attentions:
return output_tensor, present, attention_probs
return attn_output, present, attention_probs
else:
return output_tensor, present
return attn_output, present
class FalconFlashAttention2(FalconAttention):
@ -734,17 +754,20 @@ class FalconMLP(nn.Module):
return x
FALCON_ATTENTION_CLASSES = {
"eager": FalconAttention,
"sdpa": FalconAttention, # FalconAttention originally implemented both a forward with & without SDPA
"flash_attention_2": FalconFlashAttention2,
}
class FalconDecoderLayer(nn.Module):
def __init__(self, config: FalconConfig):
super().__init__()
hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.self_attention = (
FalconAttention(config)
if not getattr(config, "_flash_attn_2_enabled", False)
else FalconFlashAttention2(config)
)
self.self_attention = FALCON_ATTENTION_CLASSES[config._attn_implementation](config)
self.mlp = FalconMLP(config)
self.hidden_dropout = config.hidden_dropout
self.config = config
@ -912,6 +935,7 @@ class FalconPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["FalconDecoderLayer"]
_supports_flash_attn_2 = True
_supports_sdpa = True
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
@ -932,6 +956,25 @@ class FalconPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
# Adapted from transformers.modeling_utils.PreTrainedModel._check_and_enable_sdpa
@classmethod
def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> "PretrainedConfig":
# NOTE: Falcon supported SDPA from PyTorch 2.0. We keep it like that for backward compatibility (automatically use SDPA for torch>=2.0).
if hard_check_only:
if not is_torch_greater_or_equal_than_2_0:
raise ImportError("PyTorch SDPA requirements in Transformers are not met. Please install torch>=2.0.")
if not is_torch_greater_or_equal_than_2_0:
return config
_is_bettertransformer = getattr(cls, "use_bettertransformer", False)
if _is_bettertransformer:
return config
if not hard_check_only:
config._attn_implementation = "sdpa"
return config
@add_start_docstrings(
"The bare Falcon Model transformer outputting raw hidden-states without any specific head on top.",
@ -950,6 +993,8 @@ class FalconModel(FalconPreTrainedModel):
# Transformer blocks
self.h = nn.ModuleList([FalconDecoderLayer(config) for _ in range(config.num_hidden_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self._use_sdpa = config._attn_implementation == "sdpa"
# Final Layer Norm
self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
@ -1003,12 +1048,6 @@ class FalconModel(FalconPreTrainedModel):
if past_key_values is None:
past_key_values = tuple([None] * len(self.h))
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape batch_size x num_heads x N x N
# head_mask has shape n_layer x batch x num_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
@ -1047,15 +1086,61 @@ class FalconModel(FalconPreTrainedModel):
)
position_ids = position_ids.unsqueeze(0)
if getattr(self.config, "_flash_attn_2_enabled", False):
if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self._use_sdpa and not output_attentions:
# output_attentions=True can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
if alibi is None:
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
)
elif head_mask is None:
alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:])
attention_mask_2d = attention_mask
# We don't call _prepare_4d_causal_attention_mask_for_sdpa as we need to mask alibi using the 4D attention_mask untouched.
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)
# We take care to integrate alibi bias in the attention_mask here.
if attention_mask_2d is None:
attention_mask = alibi / math.sqrt(self.config.hidden_size // self.num_heads)
else:
attention_mask = torch.masked_fill(
alibi / math.sqrt(self.config.hidden_size // self.num_heads),
attention_mask < -1,
torch.finfo(alibi.dtype).min,
)
# From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
# produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
if seq_length > 1:
attention_mask = AttentionMaskConverter._unmask_unattended(
attention_mask, attention_mask_2d, unmasked_value=0.0
)
else:
# PyTorch SDPA does not support head_mask, we fall back on the eager implementation in this case.
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)
# Prepare head mask if needed
# 1.0 in head_mask indicates we keep the head
# attention_probs has shape batch_size x num_heads x N x N
# head_mask has shape n_layer x batch x num_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
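
The SDPA calls in this file pass `is_causal=self.is_causal and attention_mask is None and query_length > 1`. The sketch below isolates that condition as a plain function so the three cases are easy to check. The function name `use_causal_fast_path` and its parameter names are hypothetical and do not appear in the diff:

```python
from typing import Optional, Sequence


def use_causal_fast_path(
    module_is_causal: bool,
    attention_mask: Optional[Sequence],
    query_length: int,
) -> bool:
    """Decide whether to pass is_causal=True instead of an explicit mask."""
    # An explicit mask is passed as-is, so is_causal must stay False then.
    # With a single query token no causal mask is built either, which matches
    # the comment in the diff about AttentionMaskConverter.to_causal_4d.
    return module_is_causal and attention_mask is None and query_length > 1


prefill = use_causal_fast_path(True, None, 5)
single_token = use_causal_fast_path(True, None, 1)
explicit_mask = use_causal_fast_path(True, [[0.0, 0.0]], 5)
```

Only the first call selects the causal fast path. The other two leave `is_causal` as `False`.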


@ -22,6 +22,7 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
@ -128,6 +129,7 @@ class GPTBigCodeAttention(nn.Module):
self.scale_attention_softmax_in_fp32 = (
config.scale_attention_softmax_in_fp32 and config.attention_softmax_in_fp32
)
self.attn_pdrop = config.attn_pdrop
if self.is_cross_attention:
if self.multi_query:
@ -359,7 +361,7 @@ class GPTBigCodeFlashAttention2(GPTBigCodeAttention):
key = key.transpose(1, 2).reshape(batch_size, tgt, self.num_heads, self.head_dim)
value = value.transpose(1, 2).reshape(batch_size, tgt, self.num_heads, self.head_dim)
attn_dropout = self.config.attn_pdrop if self.training else 0.0
attn_dropout = self.attn_pdrop if self.training else 0.0
softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else query.dtype
upcast = query.dtype != softmax_dtype
@ -509,6 +511,137 @@ class GPTBigCodeFlashAttention2(GPTBigCodeAttention):
)
class GPTBigCodeSdpaAttention(GPTBigCodeAttention):
def _attn(self, query, key, value, attention_mask=None, head_mask=None):
if head_mask is not None:
# The super dispatch is done in the forward.
raise ValueError(
"PyTorch SDPA does not support head_mask. Please open an issue in Transformers repository."
)
scale = None
if not self.scale_attn_weights:
scale = 1
# MQA models: (batch_size, query_length, num_heads * head_dim)
# MHA models: (batch_size, num_heads, query_length, head_dim)
query_shape = query.shape
batch_size = query_shape[0]
key.shape[-2]
if self.multi_query:
query_length = query_shape[1]
# NOTE: Maybe there is a better way to do this?
query = query.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)
# Without these unsqueeze, SDPA complains as the query and key/value have a different number of dimensions.
key = key.unsqueeze(1)
value = value.unsqueeze(1)
# Although these expands are not numerically useful, PyTorch 2.1 cannot dispatch to mem-efficient attention
# and flash attention (No available kernel. Aborting execution.) from the shapes
# query = [batch_size, num_heads, query_length, head_dim]
# key = [batch_size, 1, past_length, head_dim]
# value = [batch_size, 1, past_length, head_dim]
# which is unfortunate. Hopefully this can be improved in the future. These expands should not be too expensive, as they do not copy memory.
key = key.expand(-1, self.num_heads, -1, -1)
value = value.expand(-1, self.num_heads, -1, -1)
else:
query_length = query_shape[-2]
sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query,
key,
value,
attn_mask=attention_mask,
dropout_p=self.attn_pdrop if self.training else 0.0,
# The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case query_length == 1.
is_causal=self.is_causal and attention_mask is None and query_length > 1,
scale=scale,
)
if self.multi_query:
# (batch_size, num_heads, seq_len, head_dim) --> (batch_size, seq_len, num_heads, head_dim)
sdpa_result = sdpa_result.transpose(1, 2)
# Reshape is kind of expensive here, as it does a memory copy,
# but I did not manage to get away without it (logits do not match when using view)
# (batch_size, seq_len, num_heads, head_dim) --> (batch_size, seq_len, num_heads * head_dim)
sdpa_result = sdpa_result.reshape(query_shape)
return sdpa_result, None
def forward(
self,
hidden_states: torch.Tensor,
layer_past: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
) -> Union[
Tuple[torch.Tensor, Optional[torch.Tensor]],
Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]],
]:
if encoder_hidden_states is not None:
if not hasattr(self, "q_attn") or not self.is_cross_attention:
raise ValueError(
"If class is used as cross attention, the weights `q_attn` have to be defined. "
"Please make sure to instantiate class with `GPTBigCodeAttention(..., is_cross_attention=True)`."
)
query = self.q_attn(hidden_states)
key_value = self.c_attn(encoder_hidden_states)
attention_mask = encoder_attention_mask
elif self.multi_query:
query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2)
else:
# Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim),
# i.e., the memory layout is not the same as GPT2.
# This makes the concatenation with past_key_value more efficient.
query, key_value = (
self.c_attn(hidden_states)
.view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim)
.transpose(1, 2)
.split((self.head_dim, 2 * self.head_dim), dim=3)
)
if layer_past is not None:
key_value = torch.cat((layer_past, key_value), dim=-2)
present = key_value if use_cache else None
key, value = key_value.split((self.head_dim, self.head_dim), dim=-1)
if not output_attentions and head_mask is None:
# Difference with the original implementation: there is no need to transpose the key here,
# as SDPA expects seq_length to be at index -2 for the key as well
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
else:
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
logger.warning_once(
"GPTBigCodeModel is using GPTBigCodeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or a non-None `head_mask`."
' Falling back to the manual attention implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
attn_output, attn_weights = super()._attn(query, key.transpose(-1, -2), value, attention_mask, head_mask)
if not self.multi_query:
attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape)
attn_output = self.c_proj(attn_output)
attn_output = self.resid_dropout(attn_output)
outputs = (attn_output, present)
if output_attentions:
if self.multi_query:
# Transpose to return weights in the usual format (batch_size, num_heads, query_length, key_length)
attn_weights = attn_weights.transpose(1, 2)
outputs += (attn_weights,)
return outputs
class GPTBigCodeMLP(nn.Module):
def __init__(self, intermediate_size, config):
super().__init__()
@ -527,6 +660,13 @@ class GPTBigCodeMLP(nn.Module):
return hidden_states
GPTBIGCODE_ATTENTION_CLASSES = {
"eager": GPTBigCodeAttention,
"flash_attention_2": GPTBigCodeFlashAttention2,
"sdpa": GPTBigCodeSdpaAttention,
}
class GPTBigCodeBlock(nn.Module):
def __init__(self, config, layer_idx=None):
super().__init__()
@ -534,21 +674,19 @@ class GPTBigCodeBlock(nn.Module):
self.inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = (
GPTBigCodeAttention(config, layer_idx=layer_idx)
if not getattr(config, "_flash_attn_2_enabled", False)
else GPTBigCodeFlashAttention2(config, layer_idx=layer_idx)
)
self.attn = GPTBIGCODE_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
if config.add_cross_attention:
if config.multi_query:
raise NotImplementedError("Cross-attention not implemented for MQA")
self.crossattention = (
GPTBigCodeAttention(config, is_cross_attention=True, layer_idx=layer_idx)
if not getattr(config, "_flash_attn_2_enabled", False)
else GPTBigCodeFlashAttention2(config, is_cross_attention=True, layer_idx=layer_idx)
self.crossattention = GPTBIGCODE_ATTENTION_CLASSES[config._attn_implementation](
config, is_cross_attention=True, layer_idx=layer_idx
)
self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = GPTBigCodeMLP(self.inner_dim, config)
@ -629,6 +767,7 @@ class GPTBigCodePreTrainedModel(PreTrainedModel):
_no_split_modules = ["GPTBigCodeBlock"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_sdpa = True
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
@ -770,6 +909,9 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
self.gradient_checkpointing = False
self._use_sdpa = config._attn_implementation == "sdpa"
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
# Initialize weights and apply final processing
self.post_init()
@ -850,7 +992,7 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
key_length = past_length + query_length
self_attention_mask = self.bias[None, key_length - query_length : key_length, :key_length]
if getattr(self.config, "_flash_attn_2_enabled", False):
if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask.bool() if (attention_mask is not None and 0 in attention_mask) else None
encoder_attention_mask = (
@ -867,7 +1009,34 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
# MQA models: (batch_size, query_length, n_heads, key_length)
# MHA models: (batch_size, n_heads, query_length, key_length)
attention_mask = self_attention_mask.unsqueeze(2 if self.multi_query else 1)
self_attention_mask = self_attention_mask.unsqueeze(2 if self.multi_query else 1)
if self._use_sdpa and head_mask is None and not output_attentions:
# output_attentions=True can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
if self.multi_query:
# gpt_bigcode using MQA has the bad taste to use a causal mask with shape
# [batch_size, target_length, 1, source_length], not compatible with SDPA, hence this transpose.
self_attention_mask = self_attention_mask.transpose(1, 2)
if query_length > 1 and attention_mask is not None:
# From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
# produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
self_attention_mask = AttentionMaskConverter._unmask_unattended(
self_attention_mask, attention_mask, unmasked_value=True
)
# SDPA with a custom mask is much faster in fp16/fp32 dtype rather than bool. Cast here to floating point instead of at every layer.
dtype = self.wte.weight.dtype
self_attention_mask = torch.where(
self_attention_mask,
torch.full([], 0.0, dtype=dtype, device=self_attention_mask.device),
torch.full(
[], torch.finfo(self.wte.weight.dtype).min, dtype=dtype, device=self_attention_mask.device
),
)
attention_mask = self_attention_mask
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
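
The `torch.where` call in the hunk above converts the boolean causal mask into an additive floating-point mask once, before the layers run, because (per the comment in the diff) SDPA is faster with a float mask than with a bool one. A minimal pure-Python sketch of the same mapping follows. Nested lists stand in for tensors and `MIN_VALUE` stands in for `torch.finfo(dtype).min`; both are assumptions made to keep the example dependency-free:

```python
import sys
from typing import List

# Stand-in for torch.finfo(dtype).min; the real code uses the model dtype.
MIN_VALUE = -sys.float_info.max


def to_additive_mask(bool_mask: List[List[bool]]) -> List[List[float]]:
    """Map True (attend) to 0.0 and False (masked) to a large negative bias."""
    return [[0.0 if keep else MIN_VALUE for keep in row] for row in bool_mask]


# Causal mask for two query positions over two key positions.
causal = [[True, False], [True, True]]
additive = to_additive_mask(causal)
```

The additive form is added to the attention scores before the softmax, so positions carrying `MIN_VALUE` receive effectively zero weight.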


@ -487,6 +487,12 @@ class GPTNeoFlashAttention2(GPTNeoSelfAttention):
)
GPT_NEO_ATTENTION_CLASSES = {
"eager": GPTNeoSelfAttention,
"flash_attention_2": GPTNeoFlashAttention2,
}
class GPTNeoAttention(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
@ -495,11 +501,7 @@ class GPTNeoAttention(nn.Module):
self.attention_type = self.attention_layers[layer_id]
if self.attention_type in ["global", "local"]:
self.attention = (
GPTNeoSelfAttention(config, self.attention_type)
if not getattr(config, "_flash_attn_2_enabled", False)
else GPTNeoFlashAttention2(config, self.attention_type)
)
self.attention = GPT_NEO_ATTENTION_CLASSES[config._attn_implementation](config, self.attention_type)
else:
raise NotImplementedError(
"Only attn layer types 'global' and 'local' exist, but got `config.attention_layers`: "
@ -718,6 +720,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.drop = nn.Dropout(float(config.embed_dropout))
self.h = nn.ModuleList([GPTNeoBlock(config, layer_id=i) for i in range(config.num_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
self.gradient_checkpointing = False
@ -795,7 +798,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
hidden_states = inputs_embeds + position_embeds
# Attention mask.
if getattr(self.config, "_flash_attn_2_enabled", False):
if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:


@ -658,6 +658,12 @@ class GPTNeoXMLP(nn.Module):
return hidden_states
GPT_NEOX_ATTENTION_CLASSES = {
"eager": GPTNeoXAttention,
"flash_attention_2": GPTNeoXFlashAttention2,
}
class GPTNeoXLayer(nn.Module):
def __init__(self, config):
super().__init__()
@ -666,11 +672,7 @@ class GPTNeoXLayer(nn.Module):
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.post_attention_dropout = nn.Dropout(config.hidden_dropout)
self.post_mlp_dropout = nn.Dropout(config.hidden_dropout)
self.attention = (
GPTNeoXAttention(config)
if not getattr(config, "_flash_attn_2_enabled", False)
else GPTNeoXFlashAttention2(config)
)
self.attention = GPT_NEOX_ATTENTION_CLASSES[config._attn_implementation](config)
self.mlp = GPTNeoXMLP(config)
def forward(
@ -785,6 +787,7 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
self.emb_dropout = nn.Dropout(config.hidden_dropout)
self.layers = nn.ModuleList([GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)])
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self.gradient_checkpointing = False
@ -861,7 +864,7 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
if attention_mask is not None:
assert batch_size > 0, "batch_size has to be defined and > 0"
attention_mask = attention_mask.view(batch_size, -1)
if getattr(self.config, "_flash_attn_2_enabled", False):
if self._use_flash_attention_2:
attention_mask = attention_mask if 0 in attention_mask else None
else:
# We create a 3D attention mask from a 2D tensor mask.
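
Across these models, the Flash Attention 2 branch keeps the 2D mask only when it contains at least one zero and otherwise passes `None`. The sketch below reproduces that rule with nested lists in place of tensors. The helper name `mask_for_flash_attention` is hypothetical, and `any(...)` stands in for the tensor expression `0 in attention_mask`:

```python
from typing import List, Optional

Mask2D = List[List[int]]  # 1 = attend, 0 = padding


def mask_for_flash_attention(attention_mask: Optional[Mask2D]) -> Optional[Mask2D]:
    """Return the 2D mask only if it masks out at least one position."""
    if attention_mask is None:
        return None
    has_padding = any(value == 0 for row in attention_mask for value in row)
    return attention_mask if has_padding else None


no_padding = mask_for_flash_attention([[1, 1], [1, 1]])
with_padding = mask_for_flash_attention([[1, 0], [1, 1]])
```

A batch without padding yields `None`, so no mask is passed through the layers in that case.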


@ -29,7 +29,7 @@ from torch.nn import CrossEntropyLoss
from ... import PreTrainedModel
from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask_for_sdpa
from ...modeling_outputs import ModelOutput
from ...modeling_utils import PretrainedConfig
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
@ -578,6 +578,7 @@ class IdeficsAttention(nn.Module):
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads
self.dropout = dropout
self.is_causal = True
if (self.head_dim * num_heads) != self.hidden_size:
raise ValueError(
@ -693,6 +694,8 @@ class IdeficsAttention(nn.Module):
value_states,
attn_mask=attention_mask,
dropout_p=self.dropout,
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
is_causal=self.is_causal and attention_mask is None and q_len > 1,
)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
@ -960,6 +963,7 @@ class IdeficsPreTrainedModel(PreTrainedModel):
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["IdeficsDecoderLayer", "IdeficsGatedCrossAttentionLayer"]
_supports_sdpa = True
def _init_weights(self, module):
# important: this ported version of Idefics isn't meant for training from scratch - only
@ -975,6 +979,18 @@ class IdeficsPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
# Adapted from transformers.modeling_utils.PreTrainedModel._check_and_enable_sdpa
@classmethod
def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> PretrainedConfig:
# We remove the checks on `is_torch_sdpa_available()` and `cls._supports_sdpa` as Idefics supports SDPA from torch==2.0.0 (no requirement on 2.1).
_is_bettertransformer = getattr(cls, "use_bettertransformer", False)
if _is_bettertransformer:
return config
if not hard_check_only:
config._attn_implementation = "sdpa"
return config
LLAMA_INPUTS_DOCSTRING = r"""
Args:
@ -1240,7 +1256,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
attention_mask = torch.ones(
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
)
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)

@ -34,6 +34,7 @@ from ...modeling_attn_mask_utils import (
AttentionMaskConverter,
_prepare_4d_attention_mask,
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
from ...modeling_utils import PreTrainedModel
@ -518,7 +519,7 @@ class LlamaFlashAttention2(LlamaAttention):
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
dropout_rate = 0.0 if not self.training else self.attention_dropout
dropout_rate = self.attention_dropout if self.training else 0.0
# In PEFT, we usually cast the layer norms to float32 for training stability reasons,
# therefore the input hidden states get silently cast to float32. Hence, we need
@ -654,15 +655,99 @@ class LlamaFlashAttention2(LlamaAttention):
)
class LlamaSdpaAttention(LlamaAttention):
"""
Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
`LlamaAttention` as the weights of the module stay untouched. The only changes are in the forward pass, to adapt to
the SDPA API.
"""
# Adapted from LlamaAttention.forward
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
logger.warning_once(
"LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
is_causal=self.is_causal and attention_mask is None and q_len > 1,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
LLAMA_ATTENTION_CLASSES = {
"eager": LlamaAttention,
"flash_attention_2": LlamaFlashAttention2,
"sdpa": LlamaSdpaAttention,
}
class LlamaDecoderLayer(nn.Module):
def __init__(self, config: LlamaConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = (
LlamaAttention(config=config, layer_idx=layer_idx)
if not getattr(config, "_flash_attn_2_enabled", False)
else LlamaFlashAttention2(config=config, layer_idx=layer_idx)
)
self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
self.mlp = LlamaMLP(config)
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@ -757,6 +842,7 @@ class LlamaPreTrainedModel(PreTrainedModel):
_no_split_modules = ["LlamaDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
def _init_weights(self, module):
@ -862,6 +948,8 @@ class LlamaModel(LlamaPreTrainedModel):
self.layers = nn.ModuleList(
[LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self._use_sdpa = config._attn_implementation == "sdpa"
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False
@ -922,9 +1010,18 @@ class LlamaModel(LlamaPreTrainedModel):
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if getattr(self.config, "_flash_attn_2_enabled", False):
if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self._use_sdpa and not output_attentions:
# output_attentions=True cannot be supported when using SDPA. In that case this branch is skipped and we fall back on
# the manual implementation, which requires a 4D causal mask in all cases.
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
)
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
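
The `is_causal=self.is_causal and attention_mask is None and q_len > 1` argument in `LlamaSdpaAttention.forward` can be read with a small sketch. It assumes, following the reference implementation given in the PyTorch documentation, that `is_causal=True` corresponds to a lower-triangular mask aligned to the top-left corner. The helper names `top_left_causal_mask` and `use_is_causal` are hypothetical, and plain nested lists stand in for tensors.

```python
def top_left_causal_mask(q_len, kv_len):
    """Boolean mask (True = may attend), lower-triangular and top-left aligned."""
    return [[k <= q for k in range(kv_len)] for q in range(q_len)]


def use_is_causal(module_is_causal, attention_mask, q_len):
    # Mirrors: self.is_causal and attention_mask is None and q_len > 1
    return module_is_causal and attention_mask is None and q_len > 1


# Decoding one new token against 4 cached keys: a top-left aligned causal
# mask would let the query see only the first key, which is not what we want.
decode_mask = top_left_causal_mask(q_len=1, kv_len=4)

# Prefill with 3 tokens and no cache: the usual triangular pattern.
prefill_mask = top_left_causal_mask(q_len=3, kv_len=3)

flag_when_decoding = use_is_causal(True, None, q_len=1)
flag_when_prefilling = use_is_causal(True, None, q_len=3)
flag_with_explicit_mask = use_is_causal(True, [[True]], q_len=3)
```

Under that assumption, the `q_len > 1` guard keeps the single-token decoding step unmasked, which matches the comment that `AttentionMaskConverter.to_causal_4d` does not create a causal mask when `q_len == 1`.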

@ -232,9 +232,8 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
self.multi_modal_projector = LlavaMultiModalProjector(config)
self.vocab_size = config.vocab_size
use_flash_attention_2 = getattr(config, "_flash_attn_2_enabled", False)
self.language_model = AutoModelForCausalLM.from_config(
config.text_config, use_flash_attention_2=use_flash_attention_2
config.text_config, attn_implementation=config._attn_implementation
)
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self.post_init()

@ -325,9 +325,8 @@ class M2M100EncoderLayer(nn.Module):
def __init__(self, config: M2M100Config):
super().__init__()
self.embed_dim = config.d_model
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.self_attn = M2M100_ATTENTION_CLASSES[attn_type](
self.self_attn = M2M100_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
@ -392,7 +391,7 @@ class M2M100EncoderLayer(nn.Module):
return outputs
M2M100_ATTENTION_CLASSES = {"default": M2M100Attention}
M2M100_ATTENTION_CLASSES = {"eager": M2M100Attention}
# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->M2M100, MBART->M2M100
@ -400,9 +399,8 @@ class M2M100DecoderLayer(nn.Module):
def __init__(self, config: M2M100Config):
super().__init__()
self.embed_dim = config.d_model
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.self_attn = M2M100_ATTENTION_CLASSES[attn_type](
self.self_attn = M2M100_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
@ -415,7 +413,7 @@ class M2M100DecoderLayer(nn.Module):
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.encoder_attn = M2M100_ATTENTION_CLASSES[attn_type](
self.encoder_attn = M2M100_ATTENTION_CLASSES[config._attn_implementation](
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,

@ -272,9 +272,8 @@ class MarianEncoderLayer(nn.Module):
def __init__(self, config: MarianConfig):
super().__init__()
self.embed_dim = config.d_model
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.self_attn = MARIAN_ATTENTION_CLASSES[attn_type](
self.self_attn = MARIAN_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
@ -339,7 +338,7 @@ class MarianEncoderLayer(nn.Module):
return outputs
MARIAN_ATTENTION_CLASSES = {"default": MarianAttention}
MARIAN_ATTENTION_CLASSES = {"eager": MarianAttention}
# Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->Marian, BART->MARIAN
@ -348,8 +347,7 @@ class MarianDecoderLayer(nn.Module):
super().__init__()
self.embed_dim = config.d_model
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.self_attn = MARIAN_ATTENTION_CLASSES[attn_type](
self.self_attn = MARIAN_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
@ -362,7 +360,7 @@ class MarianDecoderLayer(nn.Module):
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.encoder_attn = MARIAN_ATTENTION_CLASSES[attn_type](
self.encoder_attn = MARIAN_ATTENTION_CLASSES[config._attn_implementation](
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,

@ -501,7 +501,7 @@ class MBartFlashAttention2(MBartAttention):
MBART_ATTENTION_CLASSES = {
"default": MBartAttention,
"eager": MBartAttention,
"flash_attention_2": MBartFlashAttention2,
}
@ -510,9 +510,8 @@ class MBartEncoderLayer(nn.Module):
def __init__(self, config: MBartConfig):
super().__init__()
self.embed_dim = config.d_model
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.self_attn = MBART_ATTENTION_CLASSES[attn_type](
self.self_attn = MBART_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
@ -581,9 +580,8 @@ class MBartDecoderLayer(nn.Module):
def __init__(self, config: MBartConfig):
super().__init__()
self.embed_dim = config.d_model
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.self_attn = MBART_ATTENTION_CLASSES[attn_type](
self.self_attn = MBART_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
@ -596,7 +594,7 @@ class MBartDecoderLayer(nn.Module):
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.encoder_attn = MBART_ATTENTION_CLASSES[attn_type](
self.encoder_attn = MBART_ATTENTION_CLASSES[config._attn_implementation](
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
@ -935,6 +933,7 @@ class MBartEncoder(MBartPreTrainedModel):
embed_dim,
)
self.layers = nn.ModuleList([MBartEncoderLayer(config) for _ in range(config.encoder_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self.layernorm_embedding = nn.LayerNorm(embed_dim)
self.layer_norm = nn.LayerNorm(config.d_model)
@ -1023,7 +1022,7 @@ class MBartEncoder(MBartPreTrainedModel):
# expand attention_mask
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
if getattr(self.config, "_flash_attn_2_enabled", False):
if self._use_flash_attention_2:
attention_mask = attention_mask if 0 in attention_mask else None
else:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
@ -1112,6 +1111,7 @@ class MBartDecoder(MBartPreTrainedModel):
config.d_model,
)
self.layers = nn.ModuleList([MBartDecoderLayer(config) for _ in range(config.decoder_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self.layernorm_embedding = nn.LayerNorm(config.d_model)
self.layer_norm = nn.LayerNorm(config.d_model)
@ -1231,7 +1231,7 @@ class MBartDecoder(MBartPreTrainedModel):
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
if getattr(self.config, "_flash_attn_2_enabled", False):
if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
@ -1242,7 +1242,7 @@ class MBartDecoder(MBartPreTrainedModel):
# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
if getattr(self.config, "_flash_attn_2_enabled", False):
if self._use_flash_attention_2:
encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
else:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
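
In the Flash Attention 2 branches above, the 2D mask is passed through only when it contains at least one `0`, otherwise it is dropped. A minimal sketch of that rule on nested lists follows. The function name `mask_for_flash_attention` is hypothetical, and the explicit flattening replaces the tensor expression `0 in attention_mask`, which has no direct equivalent for nested Python lists.

```python
def mask_for_flash_attention(attention_mask):
    """Return the 2D mask unchanged if it has padding (a 0), else None."""
    if attention_mask is None:
        return None
    has_padding = any(value == 0 for row in attention_mask for value in row)
    return attention_mask if has_padding else None


no_padding = [[1, 1, 1], [1, 1, 1]]
with_padding = [[0, 1, 1], [1, 1, 1]]

result_no_padding = mask_for_flash_attention(no_padding)
result_with_padding = mask_for_flash_attention(with_padding)
result_missing = mask_for_flash_attention(None)
```

The decoder branches in the diff spell out the `attention_mask is not None` check inline, while the encoder branches sit inside an outer `if attention_mask is not None:`. The sketch folds both cases into one function.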

@ -601,15 +601,19 @@ class MistralFlashAttention2(MistralAttention):
)
MISTRAL_ATTENTION_CLASSES = {
"eager": MistralAttention,
"flash_attention_2": MistralFlashAttention2,
}
class MistralDecoderLayer(nn.Module):
def __init__(self, config: MistralConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = (
MistralAttention(config=config, layer_idx=layer_idx)
if not getattr(config, "_flash_attn_2_enabled", False)
else MistralFlashAttention2(config, layer_idx=layer_idx)
)
self.self_attn = MISTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
self.mlp = MistralMLP(config)
self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@ -807,6 +811,7 @@ class MistralModel(MistralPreTrainedModel):
self.layers = nn.ModuleList(
[MistralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False
@ -870,12 +875,7 @@ class MistralModel(MistralPreTrainedModel):
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if (
attention_mask is not None
and hasattr(self.config, "_flash_attn_2_enabled")
and self.config._flash_attn_2_enabled
and use_cache
):
if attention_mask is not None and self._use_flash_attention_2 and use_cache:
is_padding_right = attention_mask[:, -1].sum().item() != batch_size
if is_padding_right:
raise ValueError(
@ -884,7 +884,7 @@ class MistralModel(MistralPreTrainedModel):
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
)
if getattr(self.config, "_flash_attn_2_enabled", False):
if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
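
The `is_padding_right` check in `MistralModel.forward` sums the last column of the 2D mask and compares it with the batch size. A sketch of the same arithmetic on nested lists, with a hypothetical helper that takes the place of `attention_mask[:, -1].sum().item() != batch_size`:

```python
def is_padding_right(attention_mask, batch_size):
    """True if at least one row ends in a padding position (a 0)."""
    last_column_sum = sum(row[-1] for row in attention_mask)
    return last_column_sum != batch_size


left_padded = [
    [0, 0, 1, 1],
    [1, 1, 1, 1],
]
right_padded = [
    [1, 1, 0, 0],
    [1, 1, 1, 1],
]

left_result = is_padding_right(left_padded, batch_size=2)
right_result = is_padding_right(right_padded, batch_size=2)
```

With left padding every row ends in a real token, so the last column sums to the batch size and the check passes. Any right-padded row lowers the sum, which is the case the diff rejects with a `ValueError` when Flash Attention 2 and `use_cache` are both active.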

@ -491,15 +491,18 @@ class OptFlashAttention2(OPTAttention):
)
OPT_ATTENTION_CLASSES = {
"eager": OPTAttention,
"flash_attention_2": OptFlashAttention2,
}
class OPTDecoderLayer(nn.Module):
def __init__(self, config: OPTConfig):
super().__init__()
self.embed_dim = config.hidden_size
if not getattr(config, "_flash_attn_2_enabled", False):
self.self_attn = OPTAttention(config=config, is_decoder=True)
else:
self.self_attn = OptFlashAttention2(config=config, is_decoder=True)
self.self_attn = OPT_ATTENTION_CLASSES[config._attn_implementation](config=config, is_decoder=True)
self.do_layer_norm_before = config.do_layer_norm_before
self.dropout = config.dropout
@ -732,6 +735,7 @@ class OPTDecoder(OPTPreTrainedModel):
self.final_layer_norm = None
self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self.gradient_checkpointing = False
# Initialize weights and apply final processing
@ -830,7 +834,7 @@ class OPTDecoder(OPTPreTrainedModel):
mask_seq_length = past_key_values_length + seq_length
# embed positions
if getattr(self.config, "_flash_attn_2_enabled", False):
if self._use_flash_attention_2:
# 2d mask is passed through the layers
causal_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
attention_mask = (

@ -267,7 +267,7 @@ class PegasusAttention(nn.Module):
return attn_output, attn_weights_reshaped, past_key_value
PEGASUS_ATTENTION_CLASSES = {"default": PegasusAttention}
PEGASUS_ATTENTION_CLASSES = {"eager": PegasusAttention}
# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Pegasus, MBART->PEGASUS
@ -275,9 +275,8 @@ class PegasusEncoderLayer(nn.Module):
def __init__(self, config: PegasusConfig):
super().__init__()
self.embed_dim = config.d_model
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.self_attn = PEGASUS_ATTENTION_CLASSES[attn_type](
self.self_attn = PEGASUS_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
@ -347,9 +346,8 @@ class PegasusDecoderLayer(nn.Module):
def __init__(self, config: PegasusConfig):
super().__init__()
self.embed_dim = config.d_model
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.self_attn = PEGASUS_ATTENTION_CLASSES[attn_type](
self.self_attn = PEGASUS_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
@ -362,7 +360,7 @@ class PegasusDecoderLayer(nn.Module):
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.encoder_attn = PEGASUS_ATTENTION_CLASSES[attn_type](
self.encoder_attn = PEGASUS_ATTENTION_CLASSES[config._attn_implementation](
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,

@ -612,14 +612,16 @@ class PhiFlashAttention2(PhiAttention):
)
PHI_ATTENTION_CLASSES = {
"eager": PhiAttention,
"flash_attention_2": PhiFlashAttention2,
}
class PhiDecoderLayer(nn.Module):
def __init__(self, config: PhiConfig, layer_idx: int):
super().__init__()
self.self_attn = (
PhiAttention(config=config, layer_idx=layer_idx)
if not getattr(config, "_flash_attn_2_enabled", False)
else PhiFlashAttention2(config=config, layer_idx=layer_idx)
)
self.self_attn = PHI_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
self.mlp = PhiMLP(config)
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.resid_dropout = nn.Dropout(config.resid_pdrop)
@ -813,6 +815,7 @@ class PhiModel(PhiPreTrainedModel):
[PhiDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self.gradient_checkpointing = False
# Initialize weights and apply final processing
@ -876,7 +879,7 @@ class PhiModel(PhiPreTrainedModel):
inputs_embeds = self.embed_dropout(inputs_embeds)
# Attention mask.
if getattr(self.config, "_flash_attn_2_enabled", False):
if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:

@ -23,7 +23,12 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
from ...modeling_attn_mask_utils import (
_prepare_4d_attention_mask,
_prepare_4d_attention_mask_for_sdpa,
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
@ -265,9 +270,8 @@ class PLBartEncoderLayer(nn.Module):
def __init__(self, config: PLBartConfig):
super().__init__()
self.embed_dim = config.d_model
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.self_attn = PLBART_ATTENTION_CLASSES[attn_type](
self.self_attn = PLBART_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
@ -332,7 +336,8 @@ class PLBartEncoderLayer(nn.Module):
return outputs
PLBART_ATTENTION_CLASSES = {"default": PLBartAttention}
# TODO: Implement attention with SDPA for PLBart.
PLBART_ATTENTION_CLASSES = {"eager": PLBartAttention}
# Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->PLBart, BART->PLBART
@ -341,8 +346,7 @@ class PLBartDecoderLayer(nn.Module):
super().__init__()
self.embed_dim = config.d_model
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.self_attn = PLBART_ATTENTION_CLASSES[attn_type](
self.self_attn = PLBART_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
@ -355,7 +359,7 @@ class PLBartDecoderLayer(nn.Module):
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.encoder_attn = PLBART_ATTENTION_CLASSES[attn_type](
self.encoder_attn = PLBART_ATTENTION_CLASSES[config._attn_implementation](
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
@ -670,6 +674,8 @@ class PLBartEncoder(PLBartPreTrainedModel):
embed_dim,
)
self.layers = nn.ModuleList([PLBartEncoderLayer(config) for _ in range(config.encoder_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self._use_sdpa = config._attn_implementation == "sdpa"
self.layernorm_embedding = nn.LayerNorm(embed_dim)
self.gradient_checkpointing = False
@ -757,8 +763,13 @@ class PLBartEncoder(PLBartPreTrainedModel):
# expand attention_mask
if attention_mask is not None:
if getattr(self.config, "_flash_attn_2_enabled", False):
if self._use_flash_attention_2:
attention_mask = attention_mask if 0 in attention_mask else None
elif self._use_sdpa and head_mask is None and not output_attentions:
# output_attentions=True and head_mask cannot be supported when using SDPA. In those cases this branch is skipped and
# we fall back to the manual implementation, which requires a 4D mask in all cases.
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
else:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
@ -846,6 +857,9 @@ class PLBartDecoder(PLBartPreTrainedModel):
config.d_model,
)
self.layers = nn.ModuleList([PLBartDecoderLayer(config) for _ in range(config.decoder_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self._use_sdpa = config._attn_implementation == "sdpa"
self.layernorm_embedding = nn.LayerNorm(config.d_model)
self.gradient_checkpointing = False
@ -964,9 +978,18 @@ class PLBartDecoder(PLBartPreTrainedModel):
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input) * self.embed_scale
if getattr(self.config, "_flash_attn_2_enabled", False):
if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None:
# output_attentions=True and cross_attn_head_mask cannot be supported when using SDPA. In those cases this branch is
# skipped and we fall back on the manual implementation, which requires a 4D causal mask in all cases.
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
input_shape,
inputs_embeds,
past_key_values_length,
)
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
@ -975,8 +998,17 @@ class PLBartDecoder(PLBartPreTrainedModel):
# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
if getattr(self.config, "_flash_attn_2_enabled", False):
if self._use_flash_attention_2:
encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
elif self._use_sdpa and cross_attn_head_mask is None and not output_attentions:
# output_attentions=True and cross_attn_head_mask cannot be supported when using SDPA. In those cases this branch is
# skipped and we fall back on the manual implementation, which requires a 4D mask in all cases.
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
encoder_attention_mask,
inputs_embeds.dtype,
tgt_len=input_shape[-1],
)
else:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask(
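
The three-way branch for the decoder self-attention mask in `PLBartDecoder.forward` can be summarised as a small decision function. This is a sketch only: `select_causal_mask_builder` is a hypothetical name, and it returns the name of the helper the diff would call (as a string) instead of building a mask.

```python
def select_causal_mask_builder(
    use_flash_attention_2,
    use_sdpa,
    output_attentions,
    cross_attn_head_mask,
):
    """Return which mask path the decoder branch in the diff would take."""
    if use_flash_attention_2:
        # The 2D mask is passed through the layers (or dropped if it has no padding).
        return "2d_passthrough"
    if use_sdpa and not output_attentions and cross_attn_head_mask is None:
        return "_prepare_4d_causal_attention_mask_for_sdpa"
    # Eager path, also used when SDPA has to fall back to the manual implementation.
    return "_prepare_4d_causal_attention_mask"


flash_path = select_causal_mask_builder(True, False, False, None)
sdpa_path = select_causal_mask_builder(False, True, False, None)
sdpa_with_attentions = select_causal_mask_builder(False, True, True, None)
sdpa_with_head_mask = select_causal_mask_builder(False, True, False, [1.0, 0.0])
eager_path = select_causal_mask_builder(False, False, False, None)
```

The point of the ordering is that requesting attention weights or passing a head mask sends an SDPA-configured model down the same mask path as the eager implementation, which matches the fallback described in the comments of the diff.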

@ -30,7 +30,12 @@ from ...modeling_outputs import (
Seq2SeqModelOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from .configuration_speech_to_text import Speech2TextConfig
@ -326,7 +331,7 @@ class Speech2TextAttention(nn.Module):
return attn_output, attn_weights_reshaped, past_key_value
SPEECH_TO_TEXT_ATTENTION_CLASSES = {"default": Speech2TextAttention}
SPEECH_TO_TEXT_ATTENTION_CLASSES = {"eager": Speech2TextAttention}
# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Speech2Text, MBART->SPEECH_TO_TEXT
@ -334,9 +339,8 @@ class Speech2TextEncoderLayer(nn.Module):
def __init__(self, config: Speech2TextConfig):
super().__init__()
self.embed_dim = config.d_model
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.self_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[attn_type](
self.self_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
@ -406,9 +410,8 @@ class Speech2TextDecoderLayer(nn.Module):
def __init__(self, config: Speech2TextConfig):
super().__init__()
self.embed_dim = config.d_model
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.self_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[attn_type](
self.self_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
@ -421,7 +424,7 @@ class Speech2TextDecoderLayer(nn.Module):
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.encoder_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[attn_type](
self.encoder_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[config._attn_implementation](
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,

@ -32,7 +32,12 @@ from ...modeling_outputs import (
)
from ...modeling_utils import PreTrainedModel
from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from .configuration_time_series_transformer import TimeSeriesTransformerConfig
@ -436,9 +441,8 @@ class TimeSeriesTransformerEncoderLayer(nn.Module):
def __init__(self, config: TimeSeriesTransformerConfig):
super().__init__()
self.embed_dim = config.d_model
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.self_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[attn_type](
self.self_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
@ -503,7 +507,10 @@ class TimeSeriesTransformerEncoderLayer(nn.Module):
return outputs
TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES = {"default": TimeSeriesTransformerAttention}
# TODO: Implement attention with SDPA for TimeSeriesTransformer.
TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES = {
"eager": TimeSeriesTransformerAttention,
}
# Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->TimeSeriesTransformer, with BART->TIME_SERIES_TRANSFORMER
@ -512,8 +519,7 @@ class TimeSeriesTransformerDecoderLayer(nn.Module):
super().__init__()
self.embed_dim = config.d_model
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.self_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[attn_type](
self.self_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
@ -526,7 +532,7 @@ class TimeSeriesTransformerDecoderLayer(nn.Module):
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.encoder_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[attn_type](
self.encoder_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[config._attn_implementation](
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,


@ -28,7 +28,7 @@ from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...generation.logits_process import WhisperTimeStampLogitsProcessor
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
@ -690,9 +690,111 @@ class WhisperFlashAttention2(WhisperAttention):
)
class WhisperSdpaAttention(WhisperAttention):
# Copied from transformers.models.bart.modeling_bart.BartSdpaAttention.forward with BART->whisper, Bart->Whisper
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
if output_attentions or layer_head_mask is not None:
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
logger.warning_once(
"WhisperModel is using WhisperSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states,
key_value_states=key_value_states,
past_key_value=past_key_value,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, _ = hidden_states.size()
# get query proj
query_states = self.q_proj(hidden_states)
# get key, value proj
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
# is checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
if (
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
query_states = self._shape(query_states, tgt_len, bsz)
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=self.dropout if self.training else 0.0,
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
is_causal=self.is_causal and attention_mask is None and tgt_len > 1,
)
if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2)
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned across GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, None, past_key_value
WHISPER_ATTENTION_CLASSES = {
"default": WhisperAttention,
"eager": WhisperAttention,
"flash_attention_2": WhisperFlashAttention2,
"sdpa": WhisperSdpaAttention,
}
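The change above swaps the `getattr(config, "_flash_attn_2_enabled", False)` branch for a direct lookup keyed by `config._attn_implementation`. A minimal sketch of that dispatch pattern follows. `EagerAttention`, `SdpaAttention`, `ToyConfig` and `build_attention` are hypothetical stand-ins invented for illustration, not classes from the library.

```python
class EagerAttention:
    """Hypothetical stand-in for a manual (eager) attention module."""

    def __init__(self, embed_dim, num_heads, dropout=0.0):
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout


class SdpaAttention(EagerAttention):
    """Hypothetical stand-in for an SDPA-backed attention module."""


# Maps an implementation name to the class that provides it.
ATTENTION_CLASSES = {
    "eager": EagerAttention,
    "sdpa": SdpaAttention,
}


class ToyConfig:
    """Hypothetical config carrying only the fields this sketch needs."""

    def __init__(self, attn_implementation="eager"):
        self._attn_implementation = attn_implementation
        self.d_model = 16
        self.num_heads = 4
        self.attention_dropout = 0.1


def build_attention(config):
    # One dictionary lookup selects the class; an unknown name raises KeyError.
    attn_cls = ATTENTION_CLASSES[config._attn_implementation]
    return attn_cls(config.d_model, config.num_heads, dropout=config.attention_dropout)


layer = build_attention(ToyConfig("sdpa"))
print(type(layer).__name__)
```

Keeping the mapping in one dictionary means a model that lacks a given implementation simply omits the key, as the TimeSeriesTransformer mapping in this diff does by listing only `"eager"`.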
@ -701,9 +803,8 @@ class WhisperEncoderLayer(nn.Module):
def __init__(self, config: WhisperConfig):
super().__init__()
self.embed_dim = config.d_model
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.self_attn = WHISPER_ATTENTION_CLASSES[attn_type](
self.self_attn = WHISPER_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
@ -773,9 +874,8 @@ class WhisperDecoderLayer(nn.Module):
def __init__(self, config: WhisperConfig):
super().__init__()
self.embed_dim = config.d_model
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.self_attn = WHISPER_ATTENTION_CLASSES[attn_type](
self.self_attn = WHISPER_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
@ -788,7 +888,7 @@ class WhisperDecoderLayer(nn.Module):
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.encoder_attn = WHISPER_ATTENTION_CLASSES[attn_type](
self.encoder_attn = WHISPER_ATTENTION_CLASSES[config._attn_implementation](
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
@ -897,6 +997,7 @@ class WhisperPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["WhisperEncoderLayer", "WhisperDecoderLayer"]
_supports_flash_attn_2 = True
_supports_sdpa = True
def _init_weights(self, module):
std = self.config.init_std
@ -1227,6 +1328,8 @@ class WhisperDecoder(WhisperPreTrainedModel):
self.embed_positions = WhisperPositionalEmbedding(self.max_target_positions, config.d_model)
self.layers = nn.ModuleList([WhisperDecoderLayer(config) for _ in range(config.decoder_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self._use_sdpa = config._attn_implementation == "sdpa"
self.layer_norm = nn.LayerNorm(config.d_model)
@ -1336,9 +1439,14 @@ class WhisperDecoder(WhisperPreTrainedModel):
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if getattr(self.config, "_flash_attn_2_enabled", False):
if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self._use_sdpa and head_mask is None and not output_attentions:
# output_attentions=True & head_mask can not be supported when using SDPA.
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(


@ -107,6 +107,7 @@ from .utils import (
is_torch_fp16_available_on_device,
is_torch_neuroncore_available,
is_torch_npu_available,
is_torch_sdpa_available,
is_torch_tensorrt_fx_available,
is_torch_tf32_available,
is_torch_tpu_available,
@ -440,6 +441,15 @@ def require_flash_attn(test_case):
return unittest.skipUnless(is_flash_attn_2_available(), "test requires Flash Attention")(test_case)
def require_torch_sdpa(test_case):
"""
Decorator marking a test that requires PyTorch's SDPA.
These tests are skipped when requirements are not met (torch version).
"""
return unittest.skipUnless(is_torch_sdpa_available(), "test requires PyTorch SDPA")(test_case)
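The decorator above wraps `unittest.skipUnless` around an availability probe. A small self-contained sketch of the same pattern is given here, assuming a hypothetical probe `feature_available` that always reports the feature as missing. All names in the sketch are illustrative only.

```python
import unittest


def feature_available():
    """Hypothetical availability probe; always False in this sketch."""
    return False


def require_feature(test_case):
    """Skip the decorated test unless the hypothetical feature is available."""
    return unittest.skipUnless(feature_available(), "test requires the feature")(test_case)


class ExampleTest(unittest.TestCase):
    @require_feature
    def test_needs_feature(self):
        # Never reached while feature_available() returns False.
        self.fail("should have been skipped")

    def test_always_runs(self):
        self.assertEqual(1 + 1, 2)


# Run the two tests programmatically and collect the outcome.
suite = unittest.defaultTestLoader.loadTestsFromTestCase(ExampleTest)
result = unittest.TestResult()
suite.run(result)
print(len(result.skipped), len(result.failures))
```

Because the probe is evaluated when the decorator is applied, the skip decision is made at import time of the test module rather than at test run time.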
def require_peft(test_case):
"""
Decorator marking a test that requires PEFT.


@ -180,6 +180,7 @@ from .import_utils import (
is_torch_mps_available,
is_torch_neuroncore_available,
is_torch_npu_available,
is_torch_sdpa_available,
is_torch_tensorrt_fx_available,
is_torch_tf32_available,
is_torch_tpu_available,


@ -258,6 +258,19 @@ def get_torch_version():
return _torch_version
def is_torch_sdpa_available():
if not is_torch_available():
return False
elif _torch_version == "N/A":
return False
# NOTE: We require torch>=2.1 (and not torch>=2.0) to use SDPA in Transformers for two reasons:
# - Allow the global use of the `scale` argument introduced in https://github.com/pytorch/pytorch/pull/95259
# - Memory-efficient attention supports arbitrary attention_mask: https://github.com/pytorch/pytorch/pull/104310
# NOTE: We require torch>=2.1.1 to avoid a numerical issue in SDPA with non-contiguous inputs: https://github.com/pytorch/pytorch/issues/112577
return version.parse(_torch_version) >= version.parse("2.1.1")
def is_torchvision_available():
return _torchvision_available
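The `is_torch_sdpa_available` check above reduces to a version comparison against `2.1.1`, with the `"N/A"` sentinel treated as unavailable. The sketch below illustrates that comparison using only the standard library. `parse_release` and `meets_minimum` are hypothetical helpers, and unlike `packaging.version` this simplified parser ignores pre-release and local suffixes, so it should be read as an approximation rather than the library's behavior.

```python
import re


def parse_release(version_string):
    """Return the leading numeric release of a version string as a tuple of ints."""
    match = re.match(r"\d+(?:\.\d+)*", version_string)
    if match is None:
        raise ValueError(f"unrecognized version string: {version_string!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def meets_minimum(installed, minimum="2.1.1"):
    # "N/A" mirrors the sentinel the diff checks before comparing versions.
    if installed is None or installed == "N/A":
        return False
    # Tuples compare element by element, so (2, 1) sorts before (2, 1, 1).
    return parse_release(installed) >= parse_release(minimum)


for candidate in ["2.0.1", "2.1", "2.1.1", "2.2.0+cu121", "N/A"]:
    print(candidate, meets_minimum(candidate))
```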


@ -890,13 +890,11 @@ class BarkFineModelTest(ModelTesterMixin, unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
model.to(torch_device)
dummy_input = inputs_dict["input_ids"][:1]
@ -949,12 +947,13 @@ class BarkFineModelTest(ModelTesterMixin, unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
tmpdirname,
torch_dtype=torch.bfloat16,
)
model.to(torch_device)


@ -319,13 +319,11 @@ class DistilBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
model.to(torch_device)
logits = model(dummy_input, output_hidden_states=True).hidden_states[-1]
@ -373,12 +371,13 @@ class DistilBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
tmpdirname,
torch_dtype=torch.bfloat16,
)
model.to(torch_device)


@ -15,6 +15,7 @@
""" Testing suite for the PyTorch Falcon model. """
import tempfile
import unittest
from parameterized import parameterized
@ -26,7 +27,7 @@ from transformers import (
is_torch_available,
set_seed,
)
from transformers.testing_utils import require_bitsandbytes, require_torch, slow, torch_device
from transformers.testing_utils import require_bitsandbytes, require_torch, require_torch_sdpa, slow, torch_device
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
@ -437,6 +438,76 @@ class FalconModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
# The output should be different for long inputs
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
@require_torch_sdpa
@slow
def test_eager_matches_sdpa_generate(self):
max_new_tokens = 30
if len(self.all_generative_model_classes) == 0:
self.skipTest(f"{self.__class__.__name__} tests a model that does not support generate: skipping this test")
for model_class in self.all_generative_model_classes:
if not model_class._supports_sdpa:
self.skipTest(f"{model_class.__name__} does not support SDPA")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
dummy_input = inputs_dict[model_class.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
# make sure that all models have enough positions for generation
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
model_sdpa = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
).to(torch_device)
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
model_eager = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
attn_implementation="eager",
).to(torch_device)
self.assertTrue(model_eager.config._attn_implementation == "eager")
# NOTE: This check is disabled for Falcon as the non-SDPA and SDPA implementations live in the same class (for legacy reasons).
# for name, submodule in model_eager.named_modules():
# if "SdpaAttention" in submodule.__class__.__name__:
# raise ValueError("The eager model should not have SDPA attention layers")
# has_sdpa = False
# for name, submodule in model_sdpa.named_modules():
# if "SdpaAttention" in submodule.__class__.__name__:
# has_sdpa = True
# break
# if not has_sdpa:
# raise ValueError("The SDPA model should have SDPA attention layers")
# Just test that a large cache works as expected
res_eager = model_eager.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False
)
res_sdpa = model_sdpa.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False
)
self.assertTrue(torch.allclose(res_eager, res_sdpa))
@require_torch
class FalconLanguageGenerationTest(unittest.TestCase):


@ -16,11 +16,14 @@
import unittest
from parameterized import parameterized
from transformers import BitsAndBytesConfig, IdeficsConfig, is_torch_available, is_vision_available
from transformers.testing_utils import (
TestCasePlus,
require_bitsandbytes,
require_torch,
require_torch_sdpa,
require_vision,
slow,
torch_device,
@ -309,6 +312,12 @@ class IdeficsModelTester:
def prepare_pixel_values(self):
return floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@require_torch_sdpa
@slow
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
self.skipTest("Idefics has a hard requirement on SDPA, skipping this test")
@unittest.skipIf(not is_torch_greater_or_equal_than_2_0, reason="pytorch 2.0 or higher is required")
@require_torch
@ -557,6 +566,12 @@ class IdeficsModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
model = IdeficsModel.from_pretrained(model_name)
self.assertIsNotNone(model)
@require_torch_sdpa
@slow
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
self.skipTest("Idefics has a hard requirement on SDPA, skipping this test")
@unittest.skipIf(not is_torch_greater_or_equal_than_2_0, reason="pytorch 2.0 or higher is required")
@require_torch


@ -14,6 +14,7 @@
# limitations under the License.
""" Testing suite for the PyTorch LLaMA model. """
import tempfile
import unittest
import pytest
@ -26,6 +27,7 @@ from transformers.testing_utils import (
require_torch,
require_torch_accelerator,
require_torch_gpu,
require_torch_sdpa,
slow,
torch_device,
)
@ -411,7 +413,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
output_native = tokenizer.batch_decode(output_native)
model = LlamaForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf", load_in_4bit=True, device_map={"": 0}, use_flash_attention_2=True
"meta-llama/Llama-2-7b-hf", load_in_4bit=True, device_map={"": 0}, attn_implementation="flash_attention_2"
)
output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
@ -419,6 +421,85 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
self.assertListEqual(output_native, output_fa_2)
@require_flash_attn
@require_torch_gpu
@slow
def test_use_flash_attention_2_true(self):
"""
NOTE: this is the only test checking that the legacy `use_flash_attention_2=True` argument still works as intended.
"""
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with tempfile.TemporaryDirectory() as tmp_dir:
model = model_class(config)
model.save_pretrained(tmp_dir)
new_model = LlamaForCausalLM.from_pretrained(
tmp_dir, use_flash_attention_2=True, torch_dtype=torch.float16
).to("cuda")
self.assertTrue(new_model.config._attn_implementation == "flash_attention_2")
has_flash = False
for name, submodule in new_model.named_modules():
if "FlashAttention" in submodule.__class__.__name__:
has_flash = True
break
if not has_flash:
raise ValueError("The flash model should have flash attention layers")
@require_torch_sdpa
@slow
def test_eager_matches_sdpa_generate(self):
"""
Overwriting the common test, as it is flaky on tiny models
"""
max_new_tokens = 30
tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
model_sdpa = LlamaForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf",
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
).to(torch_device)
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
model_eager = LlamaForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf",
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
attn_implementation="eager",
).to(torch_device)
self.assertTrue(model_eager.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
has_sdpa = True
break
if not has_sdpa:
raise ValueError("The SDPA model should have SDPA attention layers")
texts = ["hi", "Hello this is a very long sentence my friend", "Today I am in Paris and"]
for padding_side in ["left", "right"]:
tokenizer.padding_side = padding_side
tokenizer.pad_token = tokenizer.eos_token
inputs = tokenizer(texts, return_tensors="pt", padding=True).to(torch_device)
res_eager = model_eager.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
res_sdpa = model_sdpa.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
self.assertTrue(torch.allclose(res_eager, res_sdpa))
@require_torch
class LlamaIntegrationTest(unittest.TestCase):


@ -387,9 +387,9 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=False, low_cpu_mem_usage=True
).to(torch_device)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
torch_device
)
dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
@ -397,7 +397,10 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True, low_cpu_mem_usage=True
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)
with self.assertRaises(ValueError):
@ -437,7 +440,7 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
use_flash_attention_2=True,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)
@ -507,7 +510,7 @@ class MistralIntegrationTest(unittest.TestCase):
"mistralai/Mistral-7B-v0.1",
device_map="auto",
load_in_4bit=True,
use_flash_attention_2=True,
attn_implementation="flash_attention_2",
)
input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)


@ -389,7 +389,7 @@ class PhiModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
output_native = tokenizer.batch_decode(output_native)
model = PhiForCausalLM.from_pretrained(
"susnato/phi-1_5_dev", load_in_4bit=True, device_map={"": 0}, use_flash_attention_2=True
"susnato/phi-1_5_dev", load_in_4bit=True, device_map={"": 0}, attn_implementation="flash_attention_2"
)
output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False)


@ -891,12 +891,13 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
tmpdirname,
torch_dtype=torch.bfloat16,
)
model.to(torch_device)
@ -936,11 +937,11 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True
tmpdirname, torch_dtype=torch.float16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=False)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16)
model.to(torch_device)
dummy_input = inputs_dict[model.main_input_name][:1]
@ -981,6 +982,7 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
configs_no_init.torchscript = True
configs_no_init._attn_implementation = "eager"
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
model.to(torch_device)
@ -2337,13 +2339,20 @@ class WhisperEncoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.
with torch.no_grad():
outputs = model(**inputs)[0]
input_ids = inputs["input_features"]
del inputs["input_features"]
encoder = model.encoder
encoder_inputs = {"input_features": inputs["input_features"]}
del inputs["input_features"]
if "head_mask" in inputs:
encoder_inputs["head_mask"] = inputs["head_mask"]
if "attention_mask" in inputs:
encoder_inputs["attention_mask"] = inputs["attention_mask"]
if "output_attentions" in inputs:
encoder_inputs["output_attentions"] = inputs["output_attentions"]
with torch.no_grad():
inputs["encoder_outputs"] = encoder(input_ids)
inputs["encoder_outputs"] = encoder(**encoder_inputs)
outputs_embeds = model(**inputs)[0]
self.assertTrue((outputs_embeds == outputs).all())


@ -198,7 +198,14 @@ class ConfigTestUtils(unittest.TestCase):
missing_keys = [key for key in base_config.__dict__ if key not in config_common_kwargs]
# If this part of the test fails, you have arguments to add in config_common_kwargs above.
self.assertListEqual(
missing_keys, ["is_encoder_decoder", "_name_or_path", "_commit_hash", "transformers_version"]
missing_keys,
[
"is_encoder_decoder",
"_name_or_path",
"_commit_hash",
"_attn_implementation_internal",
"transformers_version",
],
)
keys_with_defaults = [key for key, value in config_common_kwargs.items() if value == getattr(base_config, key)]
if len(keys_with_defaults) > 0:


@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import copy
import gc
@ -28,6 +27,7 @@ from collections import defaultdict
from typing import Dict, List, Tuple
import numpy as np
from parameterized import parameterized
from pytest import mark
import transformers
@ -71,6 +71,7 @@ from transformers.testing_utils import (
require_torch,
require_torch_gpu,
require_torch_multi_gpu,
require_torch_sdpa,
slow,
torch_device,
)
@ -776,102 +777,120 @@ class ModelTesterMixin:
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
configs_no_init.torchscript = True
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
for attn_implementation in ["eager", "sdpa"]:
if attn_implementation == "sdpa" and not model_class._supports_sdpa:
continue
main_input_name = model_class.main_input_name
configs_no_init._attn_implementation = attn_implementation
model = model_class(config=configs_no_init)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
try:
if model.config.is_encoder_decoder:
model.config.use_cache = False # FSMT still requires this hack -> FSMT should probably be refactored similar to BART afterward
main_input = inputs[main_input_name]
attention_mask = inputs["attention_mask"]
decoder_input_ids = inputs["decoder_input_ids"]
decoder_attention_mask = inputs["decoder_attention_mask"]
model(main_input, attention_mask, decoder_input_ids, decoder_attention_mask)
traced_model = torch.jit.trace(
model, (main_input, attention_mask, decoder_input_ids, decoder_attention_mask)
)
elif "bbox" in inputs and "image" in inputs: # LayoutLMv2 requires additional inputs
input_ids = inputs["input_ids"]
bbox = inputs["bbox"]
image = inputs["image"].tensor
model(input_ids, bbox, image)
traced_model = torch.jit.trace(
model, (input_ids, bbox, image), check_trace=False
) # when traced model is checked, an error is produced due to name mangling
elif "bbox" in inputs: # Bros requires additional inputs (bbox)
input_ids = inputs["input_ids"]
bbox = inputs["bbox"]
model(input_ids, bbox)
traced_model = torch.jit.trace(
model, (input_ids, bbox), check_trace=False
) # when traced model is checked, an error is produced due to name mangling
else:
main_input = inputs[main_input_name]
model(main_input)
traced_model = torch.jit.trace(model, main_input)
except RuntimeError:
self.fail("Couldn't trace module.")
with tempfile.TemporaryDirectory() as tmp_dir_name:
pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")
main_input_name = model_class.main_input_name
try:
torch.jit.save(traced_model, pt_file_name)
except Exception:
self.fail("Couldn't save module.")
if model.config.is_encoder_decoder:
model.config.use_cache = False # FSMT still requires this hack -> FSMT should probably be refactored similar to BART afterward
main_input = inputs[main_input_name]
attention_mask = inputs["attention_mask"]
decoder_input_ids = inputs["decoder_input_ids"]
decoder_attention_mask = inputs["decoder_attention_mask"]
model(main_input, attention_mask, decoder_input_ids, decoder_attention_mask)
traced_model = torch.jit.trace(
model, (main_input, attention_mask, decoder_input_ids, decoder_attention_mask)
)
elif "bbox" in inputs and "image" in inputs: # LayoutLMv2 requires additional inputs
input_ids = inputs["input_ids"]
bbox = inputs["bbox"]
image = inputs["image"].tensor
model(input_ids, bbox, image)
traced_model = torch.jit.trace(
model, (input_ids, bbox, image), check_trace=False
) # when traced model is checked, an error is produced due to name mangling
elif "bbox" in inputs: # Bros requires additional inputs (bbox)
input_ids = inputs["input_ids"]
bbox = inputs["bbox"]
model(input_ids, bbox)
traced_model = torch.jit.trace(
model, (input_ids, bbox), check_trace=False
) # when traced model is checked, an error is produced due to name mangling
else:
main_input = inputs[main_input_name]
if model.config._attn_implementation == "sdpa":
trace_input = {main_input_name: main_input}
if "attention_mask" in inputs:
trace_input["attention_mask"] = inputs["attention_mask"]
else:
self.skipTest("testing SDPA without attention_mask is not supported")
model(main_input, attention_mask=inputs["attention_mask"])
# example_kwarg_inputs was introduced in torch==2.0, but it is fine here since SDPA has a requirement on torch>=2.1.
traced_model = torch.jit.trace(model, example_kwarg_inputs=trace_input)
else:
model(main_input)
traced_model = torch.jit.trace(model, (main_input,))
except RuntimeError:
self.fail("Couldn't trace module.")
with tempfile.TemporaryDirectory() as tmp_dir_name:
pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")
try:
torch.jit.save(traced_model, pt_file_name)
except Exception:
self.fail("Couldn't save module.")
try:
loaded_model = torch.jit.load(pt_file_name)
except Exception:
self.fail("Couldn't load module.")
model.to(torch_device)
model.eval()
loaded_model.to(torch_device)
loaded_model.eval()
model_state_dict = model.state_dict()
loaded_model_state_dict = loaded_model.state_dict()
non_persistent_buffers = {}
for key in loaded_model_state_dict.keys():
if key not in model_state_dict.keys():
non_persistent_buffers[key] = loaded_model_state_dict[key]
loaded_model_state_dict = {
key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
}
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
model_buffers = list(model.buffers())
for non_persistent_buffer in non_persistent_buffers.values():
found_buffer = False
for i, model_buffer in enumerate(model_buffers):
if torch.equal(non_persistent_buffer, model_buffer):
found_buffer = True
break
self.assertTrue(found_buffer)
model_buffers.pop(i)
models_equal = True
for layer_name, p1 in model_state_dict.items():
if layer_name in loaded_model_state_dict:
p2 = loaded_model_state_dict[layer_name]
if p1.data.ne(p2.data).sum() > 0:
models_equal = False
self.assertTrue(models_equal)
# Avoid memory leak. Without this, each call increases RAM usage by ~20MB.
# (Even with this call, a memory leak of ~0.04MB remains.)
self.clear_torch_jit_class_registry()
def test_torch_fx(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@ -2832,8 +2851,6 @@ class ModelTesterMixin:
@mark.flash_attn_test
@slow
def test_flash_attn_2_conversion(self):
import torch
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
@ -2845,7 +2862,7 @@ class ModelTesterMixin:
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True
tmpdirname, torch_dtype=torch.float16, attn_implementation="flash_attention_2"
).to(torch_device)
for _, module in model.named_modules():
@ -2859,8 +2876,6 @@ class ModelTesterMixin:
@mark.flash_attn_test
@slow
def test_flash_attn_2_inference(self):
import torch
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
@ -2871,12 +2886,12 @@ class ModelTesterMixin:
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="eager"
)
model.to(torch_device)
@ -2956,8 +2971,6 @@ class ModelTesterMixin:
@mark.flash_attn_test
@slow
def test_flash_attn_2_inference_padding_right(self):
import torch
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
@ -2968,12 +2981,12 @@ class ModelTesterMixin:
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="eager"
)
model.to(torch_device)
@ -3049,8 +3062,6 @@ class ModelTesterMixin:
@mark.flash_attn_test
@slow
def test_flash_attn_2_generate_left_padding(self):
import torch
for model_class in self.all_generative_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
@ -3060,9 +3071,9 @@ class ModelTesterMixin:
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=False, low_cpu_mem_usage=True
).to(torch_device)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
torch_device
)
dummy_input = inputs_dict[model.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
@ -3078,7 +3089,10 @@ class ModelTesterMixin:
)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True, low_cpu_mem_usage=True
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)
out_fa = model.generate(
@ -3092,8 +3106,6 @@ class ModelTesterMixin:
@mark.flash_attn_test
@slow
def test_flash_attn_2_generate_padding_right(self):
import torch
for model_class in self.all_generative_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
@ -3103,9 +3115,9 @@ class ModelTesterMixin:
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=False, low_cpu_mem_usage=True
).to(torch_device)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
torch_device
)
dummy_input = inputs_dict[model.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
@ -3121,7 +3133,10 @@ class ModelTesterMixin:
)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True, low_cpu_mem_usage=True
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)
out_fa = model.generate(
@ -3130,13 +3145,330 @@ class ModelTesterMixin:
self.assertTrue(torch.allclose(out, out_fa))
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
@slow
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
if not self.all_model_classes[0]._supports_sdpa:
self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")
if torch_device == "cpu" and torch_dtype == "float16":
self.skipTest("float16 not supported on cpu")
# Not sure whether it's fine to put torch.XXX in a decorator if torch is not available so hacking it here instead.
if torch_dtype == "float16":
torch_dtype = torch.float16
elif torch_dtype == "bfloat16":
torch_dtype = torch.bfloat16
elif torch_dtype == "float32":
torch_dtype = torch.float32
atols = {
("cpu", False, torch.float32): 1e-6,
("cpu", False, torch.bfloat16): 1e-2,
("cpu", True, torch.float32): 1e-6,
("cpu", True, torch.bfloat16): 1e-2,
("cuda", False, torch.float32): 1e-6,
("cuda", False, torch.bfloat16): 1e-2,
("cuda", False, torch.float16): 1e-3,
("cuda", True, torch.float32): 1e-6,
("cuda", True, torch.bfloat16): 1e-2,
("cuda", True, torch.float16): 5e-3,
}
rtols = {
("cpu", False, torch.float32): 1e-4,
("cpu", False, torch.bfloat16): 1e-2,
("cpu", True, torch.float32): 1e-4,
("cpu", True, torch.bfloat16): 1e-2,
("cuda", False, torch.float32): 1e-4,
("cuda", False, torch.bfloat16): 1e-2,
("cuda", False, torch.float16): 1e-3,
("cuda", True, torch.float32): 1e-4,
("cuda", True, torch.bfloat16): 3e-2,
("cuda", True, torch.float16): 5e-3,
}
def get_mean_reldiff(failcase, x, ref, atol, rtol):
return f"{failcase}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}"
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
is_encoder_decoder = model.config.is_encoder_decoder
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype)
model_sdpa = model_sdpa.eval().to(torch_device)
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
model_eager = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch_dtype,
attn_implementation="eager",
)
model_eager = model_eager.eval().to(torch_device)
self.assertTrue(model_eager.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
has_sdpa = True
break
if not has_sdpa and model_sdpa.config.model_type != "falcon":
raise ValueError("The SDPA model should have SDPA attention layers")
# We use these for loops instead of parameterized.expand to avoid saving and loading the model 8 times,
# but it would be nicer to have an efficient way to use parameterized.expand
fail_cases = []
for padding_side in ["left", "right"]:
for use_mask in [False, True]:
for batch_size in [1, 5]:
dummy_input = inputs_dict[model.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]:
dummy_input = dummy_input.to(torch_dtype)
dummy_input = dummy_input[:batch_size]
if dummy_input.shape[0] != batch_size:
if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]:
extension = torch.rand(
batch_size - dummy_input.shape[0],
*dummy_input.shape[1:],
dtype=torch_dtype,
device=torch_device,
)
dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device)
else:
extension = torch.randint(
high=5,
size=(batch_size - dummy_input.shape[0], *dummy_input.shape[1:]),
dtype=dummy_input.dtype,
device=torch_device,
)
dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device)
if not use_mask:
dummy_attention_mask = None
else:
dummy_attention_mask = inputs_dict.get("attention_mask", None)
if dummy_attention_mask is None:
if is_encoder_decoder:
seqlen = inputs_dict.get("decoder_input_ids", dummy_input).shape[-1]
else:
seqlen = dummy_input.shape[-1]
dummy_attention_mask = (
torch.ones(batch_size, seqlen).to(torch.int64).to(torch_device)
)
dummy_attention_mask = dummy_attention_mask[:batch_size]
if dummy_attention_mask.shape[0] != batch_size:
extension = torch.ones(
batch_size - dummy_attention_mask.shape[0],
*dummy_attention_mask.shape[1:],
dtype=dummy_attention_mask.dtype,
device=torch_device,
)
dummy_attention_mask = torch.cat((dummy_attention_mask, extension), dim=0)
dummy_attention_mask = dummy_attention_mask.to(torch_device)
dummy_attention_mask[:] = 1
if padding_side == "left":
dummy_attention_mask[-1, :-1] = 1
dummy_attention_mask[-1, -4:] = 0
elif padding_side == "right":
dummy_attention_mask[-1, 1:] = 1
dummy_attention_mask[-1, :3] = 0
for enable_kernels in [False, True]:
failcase = f"padding_side={padding_side}, use_mask={use_mask}, batch_size={batch_size}, enable_kernels={enable_kernels}"
if is_encoder_decoder:
decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:batch_size]
if decoder_input_ids.shape[0] != batch_size:
extension = torch.ones(
batch_size - decoder_input_ids.shape[0],
*decoder_input_ids.shape[1:],
dtype=decoder_input_ids.dtype,
device=torch_device,
)
decoder_input_ids = torch.cat((decoder_input_ids, extension), dim=0)
decoder_input_ids = decoder_input_ids.to(torch_device)
# TODO: never an `attention_mask` arg here?
other_inputs = {
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": dummy_attention_mask,
"output_hidden_states": True,
}
else:
other_inputs = {
"output_hidden_states": True,
}
# Otherwise fails for e.g. WhisperEncoderModel
if "attention_mask" in inspect.signature(model_eager.forward).parameters:
other_inputs["attention_mask"] = dummy_attention_mask
# TODO: test gradients as well (& for FA2 as well!)
with torch.no_grad():
with torch.backends.cuda.sdp_kernel(
enable_flash=enable_kernels,
enable_math=True,
enable_mem_efficient=enable_kernels,
):
outputs_eager = model_eager(dummy_input, **other_inputs)
outputs_sdpa = model_sdpa(dummy_input, **other_inputs)
logits_eager = (
outputs_eager.hidden_states[-1]
if not is_encoder_decoder
else outputs_eager.decoder_hidden_states[-1]
)
logits_sdpa = (
outputs_sdpa.hidden_states[-1]
if not is_encoder_decoder
else outputs_sdpa.decoder_hidden_states[-1]
)
if torch_device in ["cpu", "cuda"]:
atol = atols[torch_device, enable_kernels, torch_dtype]
rtol = rtols[torch_device, enable_kernels, torch_dtype]
else:
atol = 1e-7
rtol = 1e-4
# Masked tokens output slightly deviates - we don't mind that.
if use_mask:
if padding_side == "left":
sub_sdpa = logits_sdpa[:-1]
sub_eager = logits_eager[:-1]
if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
fail_cases.append(
get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
)
sub_sdpa = logits_sdpa[-1, :-4]
sub_eager = logits_eager[-1, :-4]
if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
fail_cases.append(
get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
)
# Testing the padding tokens is not really meaningful but anyway
# sub_sdpa = logits_sdpa[-1, -4:]
# sub_eager = logits_eager[-1, -4:]
# if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
# fail_cases.append(get_mean_reldiff(failcase, sub_sdpa, sub_eager, 4e-2, 4e-2))
elif padding_side == "right":
sub_sdpa = logits_sdpa[:-1]
sub_eager = logits_eager[:-1]
if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
fail_cases.append(
get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
)
sub_sdpa = logits_sdpa[-1, 3:]
sub_eager = logits_eager[-1, 3:]
if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
fail_cases.append(
get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
)
# Testing the padding tokens is not really meaningful but anyway
# sub_sdpa = logits_sdpa[-1, :3]
# sub_eager = logits_eager[-1, :3]
# if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
# fail_cases.append(get_mean_reldiff(failcase, sub_sdpa, sub_eager, 4e-2, 4e-2))
else:
if not torch.allclose(logits_sdpa, logits_eager, atol=atol, rtol=rtol):
fail_cases.append(
get_mean_reldiff(failcase, logits_sdpa, logits_eager, atol, rtol)
)
self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases))
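The pass/fail criterion used by `test_eager_matches_sdpa_inference` can be sketched without torch. The snippet below is a minimal illustration, assuming the element-wise rule documented for `torch.allclose`, `|x - ref| <= atol + rtol * |ref|`. The helper names `allclose_lists` and `mean_reldiff` are hypothetical, and the input values are made up; they stand in for the eager and SDPA hidden states.

```python
def allclose_lists(x, ref, atol, rtol):
    """Element-wise check |x - ref| <= atol + rtol * |ref| on plain lists."""
    return all(abs(a - b) <= atol + rtol * abs(b) for a, b in zip(x, ref))


def mean_reldiff(x, ref, eps=1e-12):
    """Mean of |x - ref| / (|ref| + eps), the diagnostic reported for failing cases."""
    return sum(abs(a - b) / (abs(b) + eps) for a, b in zip(x, ref)) / len(ref)


# Hypothetical values standing in for eager and SDPA outputs.
eager = [1.0, -2.0, 0.5]
sdpa = [1.0005, -2.001, 0.5002]

# Tolerances copied from the ("cuda", False, torch.float16) table entries.
passes_fp16 = allclose_lists(sdpa, eager, atol=1e-3, rtol=1e-3)
# Tolerances copied from the float32 table entries.
passes_fp32 = allclose_lists(sdpa, eager, atol=1e-6, rtol=1e-4)

print(passes_fp16, passes_fp32, f"{mean_reldiff(sdpa, eager):.3e}")
```

The same deviation passes under the float16 tolerances and fails under the float32 ones, which is why the test keys its tolerance tables on device, kernel setting, and dtype.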
@require_torch_sdpa
@slow
def test_eager_matches_sdpa_generate(self):
max_new_tokens = 30
if len(self.all_generative_model_classes) == 0:
self.skipTest(f"{self.__class__.__name__} tests a model that does not support generate: skipping this test")
for model_class in self.all_generative_model_classes:
if not model_class._supports_sdpa:
self.skipTest(f"{model_class.__name__} does not support SDPA")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
dummy_input = inputs_dict[model_class.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
# make sure that all models have enough positions for generation
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
model_sdpa = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
).to(torch_device)
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
model_eager = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
attn_implementation="eager",
).to(torch_device)
self.assertTrue(model_eager.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
has_sdpa = True
break
if not has_sdpa:
raise ValueError("The SDPA model should have SDPA attention layers")
# Just test that a large cache works as expected
res_eager = model_eager.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False
)
res_sdpa = model_sdpa.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False
)
self.assertTrue(torch.allclose(res_eager, res_sdpa))
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
def test_flash_attn_2_generate_use_cache(self):
import torch
max_new_tokens = 30
for model_class in self.all_generative_model_classes:
@ -3163,7 +3495,7 @@ class ModelTesterMixin:
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
use_flash_attention_2=True,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)
@ -3182,8 +3514,6 @@ class ModelTesterMixin:
@mark.flash_attn_test
@slow
def test_flash_attn_2_fp32_ln(self):
import torch
for model_class in self.all_generative_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
@ -3204,7 +3534,7 @@ class ModelTesterMixin:
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
use_flash_attention_2=True,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
load_in_4bit=True,
)
@ -3282,8 +3612,6 @@ class ModelTesterMixin:
@mark.flash_attn_test
@slow
def test_flash_attn_2_from_config(self):
import torch
for model_class in self.all_generative_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
@ -3291,7 +3619,7 @@ class ModelTesterMixin:
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
# TODO: to change it in the future with other relevant auto classes
fa2_model = AutoModelForCausalLM.from_config(
config, use_flash_attention_2=True, torch_dtype=torch.bfloat16
config, attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16
).to(torch_device)
dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
@ -3313,7 +3641,7 @@ class ModelTesterMixin:
model_from_pretrained = AutoModelForCausalLM.from_pretrained(tmpdirname)
self.assertFalse(getattr(model_from_pretrained.config, "_flash_attn_2_enabled", False))
self.assertTrue(model_from_pretrained.config._attn_implementation != "flash_attention_2")
fa2_correctly_converted = False


@ -60,7 +60,13 @@ from transformers.utils import (
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
)
from transformers.utils.import_utils import is_flax_available, is_tf_available, is_torchdynamo_available
from transformers.utils.import_utils import (
is_flash_attn_2_available,
is_flax_available,
is_tf_available,
is_torch_sdpa_available,
is_torchdynamo_available,
)
sys.path.append(str(Path(__file__).parent.parent / "utils"))
@ -1689,3 +1695,158 @@ class AttentionMaskTester(unittest.TestCase):
res_compiled = compiled_model(mask, inputs_embeds)
self.assertTrue(torch.equal(res_non_compiled, res_compiled))
@require_torch
@slow
def test_unmask_unattended_left_padding(self):
attention_mask = torch.Tensor([[0, 0, 1], [1, 1, 1], [0, 1, 1]]).to(torch.int64)
expanded_mask = torch.Tensor(
[
[[[0, 0, 0], [0, 0, 0], [0, 0, 1]]],
[[[1, 0, 0], [1, 1, 0], [1, 1, 1]]],
[[[0, 0, 0], [0, 1, 0], [0, 1, 1]]],
]
).to(torch.int64)
reference_output = torch.Tensor(
[
[[[1, 1, 1], [1, 1, 1], [0, 0, 1]]],
[[[1, 0, 0], [1, 1, 0], [1, 1, 1]]],
[[[1, 1, 1], [0, 1, 0], [0, 1, 1]]],
]
).to(torch.int64)
result = AttentionMaskConverter._unmask_unattended(expanded_mask, attention_mask, unmasked_value=1)
self.assertTrue(torch.equal(result, reference_output))
attention_mask = torch.Tensor([[0, 0, 1, 1, 1], [1, 1, 1, 1, 1], [0, 1, 1, 1, 1]]).to(torch.int64)
attn_mask_converter = AttentionMaskConverter(is_causal=True)
past_key_values_length = 0
key_value_length = attention_mask.shape[-1] + past_key_values_length
expanded_mask = attn_mask_converter.to_4d(
attention_mask, attention_mask.shape[-1], key_value_length=key_value_length, dtype=torch.float32
)
result = AttentionMaskConverter._unmask_unattended(expanded_mask, attention_mask, unmasked_value=0)
min_inf = torch.finfo(torch.float32).min
reference_output = torch.Tensor(
[
[
[
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[min_inf, min_inf, 0, min_inf, min_inf],
[min_inf, min_inf, 0, 0, min_inf],
[min_inf, min_inf, 0, 0, 0],
]
],
[
[
[0, min_inf, min_inf, min_inf, min_inf],
[0, 0, min_inf, min_inf, min_inf],
[0, 0, 0, min_inf, min_inf],
[0, 0, 0, 0, min_inf],
[0, 0, 0, 0, 0],
]
],
[
[
[0, 0, 0, 0, 0],
[min_inf, 0, min_inf, min_inf, min_inf],
[min_inf, 0, 0, min_inf, min_inf],
[min_inf, 0, 0, 0, min_inf],
[min_inf, 0, 0, 0, 0],
]
],
]
)
self.assertTrue(torch.equal(reference_output, result))
@require_torch
@slow
def test_unmask_unattended_right_padding(self):
attention_mask = torch.Tensor([[1, 1, 1, 0], [1, 1, 1, 1], [1, 1, 0, 0]]).to(torch.int64)
attn_mask_converter = AttentionMaskConverter(is_causal=True)
past_key_values_length = 0
key_value_length = attention_mask.shape[-1] + past_key_values_length
expanded_mask = attn_mask_converter.to_4d(
attention_mask, attention_mask.shape[-1], key_value_length=key_value_length, dtype=torch.float32
)
result = AttentionMaskConverter._unmask_unattended(expanded_mask, attention_mask, unmasked_value=0)
self.assertTrue(torch.equal(expanded_mask, result))
@require_torch
@slow
def test_unmask_unattended_random_mask(self):
attention_mask = torch.Tensor([[1, 0, 1, 0], [1, 0, 1, 1], [1, 1, 0, 1]]).to(torch.int64)
attn_mask_converter = AttentionMaskConverter(is_causal=True)
past_key_values_length = 0
key_value_length = attention_mask.shape[-1] + past_key_values_length
expanded_mask = attn_mask_converter.to_4d(
attention_mask, attention_mask.shape[-1], key_value_length=key_value_length, dtype=torch.float32
)
result = AttentionMaskConverter._unmask_unattended(expanded_mask, attention_mask, unmasked_value=0)
self.assertTrue(torch.equal(expanded_mask, result))
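The behaviour these three tests expect from `AttentionMaskConverter._unmask_unattended` can be illustrated on nested lists. This is a sketch of the observable effect only, not the library's implementation (which also takes the 2D `attention_mask`): it assumes that the rows to unmask are exactly the rows with no unmasked entry, which holds for the causal masks used above. The function name `unmask_fully_masked_rows` is hypothetical.

```python
def unmask_fully_masked_rows(expanded_mask, unmasked_value):
    """Return a copy where every row without any unmasked entry is filled with unmasked_value.

    expanded_mask is a nested list shaped [batch][head][query][key].
    """
    result = []
    for heads in expanded_mask:
        new_heads = []
        for rows in heads:
            new_rows = []
            for row in rows:
                if any(value == unmasked_value for value in row):
                    new_rows.append(list(row))  # row attends to something: keep as is
                else:
                    new_rows.append([unmasked_value] * len(row))  # fully masked: unmask
            new_heads.append(new_rows)
        result.append(new_heads)
    return result


# Integer example copied from test_unmask_unattended_left_padding.
expanded_mask = [
    [[[0, 0, 0], [0, 0, 0], [0, 0, 1]]],
    [[[1, 0, 0], [1, 1, 0], [1, 1, 1]]],
    [[[0, 0, 0], [0, 1, 0], [0, 1, 1]]],
]
reference_output = [
    [[[1, 1, 1], [1, 1, 1], [0, 0, 1]]],
    [[[1, 0, 0], [1, 1, 0], [1, 1, 1]]],
    [[[1, 1, 1], [0, 1, 0], [0, 1, 1]]],
]
result = unmask_fully_masked_rows(expanded_mask, unmasked_value=1)
```

With right padding or a non-contiguous mask whose first token is attended, no row is fully masked, so the sketch returns the input unchanged, matching the last two tests.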
@require_torch
class TestAttentionImplementation(unittest.TestCase):
def test_error_no_sdpa_available(self):
with self.assertRaises(ValueError) as cm:
_ = AutoModel.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="sdpa")
self.assertTrue(
"does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention"
in str(cm.exception)
)
_ = AutoModel.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel")
def test_error_no_flash_available(self):
with self.assertRaises(ValueError) as cm:
_ = AutoModel.from_pretrained(
"hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="flash_attention_2"
)
self.assertTrue("does not support Flash Attention 2.0" in str(cm.exception))
def test_error_wrong_attn_implementation(self):
with self.assertRaises(ValueError) as cm:
_ = AutoModel.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="foo")
self.assertTrue('The only possible arguments are `attn_implementation="eager"' in str(cm.exception))
def test_not_available_flash(self):
if is_flash_attn_2_available():
self.skipTest("Please uninstall flash-attn package to run test_not_available_flash")
with self.assertRaises(ImportError) as cm:
_ = AutoModel.from_pretrained(
"hf-internal-testing/tiny-random-GPTBigCodeModel", attn_implementation="flash_attention_2"
)
self.assertTrue("the package flash_attn seems to be not installed" in str(cm.exception))
def test_not_available_sdpa(self):
if is_torch_sdpa_available():
self.skipTest("This test requires torch<=2.0")
with self.assertRaises(ImportError) as cm:
_ = AutoModel.from_pretrained(
"hf-internal-testing/tiny-random-GPTBigCodeModel", attn_implementation="sdpa"
)
self.assertTrue("PyTorch SDPA requirements in Transformers are not met" in str(cm.exception))


@ -12,11 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import doctest
import logging
import os
import unittest
from glob import glob
from pathlib import Path
from typing import List, Union
@ -27,6 +27,63 @@ from transformers.testing_utils import require_tf, require_torch, slow
logger = logging.getLogger()
@require_torch
class TestDocLists(unittest.TestCase):
def test_flash_support_list(self):
with open("./docs/source/en/perf_infer_gpu_one.md", "r") as f:
doctext = f.read()
doctext = doctext.split("FlashAttention-2 is currently supported for the following architectures:")[1]
doctext = doctext.split("You can request to add FlashAttention-2 support")[0]
patterns = glob("./src/transformers/models/**/modeling_*.py")
patterns_tf = glob("./src/transformers/models/**/modeling_tf_*.py")
patterns_flax = glob("./src/transformers/models/**/modeling_flax_*.py")
patterns = list(set(patterns) - set(patterns_tf) - set(patterns_flax))
archs_supporting_fa2 = []
for filename in patterns:
with open(filename, "r") as f:
text = f.read()
if "_supports_flash_attn_2 = True" in text:
model_name = os.path.basename(filename).replace(".py", "").replace("modeling_", "")
archs_supporting_fa2.append(model_name)
for arch in archs_supporting_fa2:
if arch not in doctext:
raise ValueError(
f"{arch} should be listed in the flash attention documentation but is not. Please update the documentation."
)
def test_sdpa_support_list(self):
with open("./docs/source/en/perf_infer_gpu_one.md", "r") as f:
doctext = f.read()
doctext = doctext.split(
"For now, Transformers supports inference and training through SDPA for the following architectures:"
)[1]
doctext = doctext.split("Note that FlashAttention can only be used for models using the")[0]
patterns = glob("./src/transformers/models/**/modeling_*.py")
patterns_tf = glob("./src/transformers/models/**/modeling_tf_*.py")
patterns_flax = glob("./src/transformers/models/**/modeling_flax_*.py")
patterns = list(set(patterns) - set(patterns_tf) - set(patterns_flax))
archs_supporting_sdpa = []
for filename in patterns:
with open(filename, "r") as f:
text = f.read()
if "_supports_sdpa = True" in text:
model_name = os.path.basename(filename).replace(".py", "").replace("modeling_", "")
archs_supporting_sdpa.append(model_name)
for arch in archs_supporting_sdpa:
if arch not in doctext:
raise ValueError(
f"{arch} should be listed in the SDPA documentation but is not. Please update the documentation."
)
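The two doc-list tests share one pattern: scan the PyTorch modeling files for a support flag, then check that each matching architecture name appears in the documentation text. A self-contained sketch of that pattern follows, run against a throwaway directory instead of the real repository. The helper names `archs_with_flag` and `missing_from_doc`, and the file contents written to the temporary directory, are hypothetical.

```python
import os
import tempfile
from glob import glob


def archs_with_flag(models_dir, flag):
    """Collect architecture names whose PyTorch modeling file contains `flag`."""
    patterns = glob(os.path.join(models_dir, "**", "modeling_*.py"))
    patterns_tf = glob(os.path.join(models_dir, "**", "modeling_tf_*.py"))
    patterns_flax = glob(os.path.join(models_dir, "**", "modeling_flax_*.py"))
    archs = []
    # Drop TF and Flax files, as the tests do, and sort for a stable order.
    for filename in sorted(set(patterns) - set(patterns_tf) - set(patterns_flax)):
        with open(filename, "r") as f:
            if flag in f.read():
                archs.append(os.path.basename(filename).replace(".py", "").replace("modeling_", ""))
    return archs


def missing_from_doc(archs, doctext):
    """Return the architectures that are not mentioned in the documentation text."""
    return [arch for arch in archs if arch not in doctext]


with tempfile.TemporaryDirectory() as tmp:
    # Fake model tree: llama carries the flag, bert does not.
    for name, body in [("llama", "_supports_sdpa = True"), ("bert", "pass")]:
        os.makedirs(os.path.join(tmp, name))
        with open(os.path.join(tmp, name, f"modeling_{name}.py"), "w") as f:
            f.write(body)
    # A TF file with the flag must be ignored.
    with open(os.path.join(tmp, "llama", "modeling_tf_llama.py"), "w") as f:
        f.write("_supports_sdpa = True")

    found = archs_with_flag(tmp, "_supports_sdpa = True")
    missing = missing_from_doc(found, "* falcon\n* gpt_bigcode")
```

Note that the membership check is a plain substring test, as in the tests above, so a short architecture name can be matched by a longer name that contains it.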
@unittest.skip("Temporarily disable the doc tests.")
@require_torch
@require_tf