mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
F.scaled_dot_product_attention support (#26572)
* add sdpa * wip * cleaning * add ref * yet more cleaning * and more :) * wip llama * working llama * add output_attentions=True support * bigcode sdpa support * fixes * gpt-bigcode support, require torch>=2.1.1 * add falcon support * fix conflicts falcon * style * fix attention_mask definition * remove output_attentions from attnmaskconverter * support whisper without removing any Copied from statement * fix mbart default to eager renaming * fix typo in falcon * fix is_causal in SDPA * check is_flash_attn_2_available in the models init as well in case the model is not initialized through from_pretrained * add warnings when falling back on the manual implementation * precise doc * wip replace _flash_attn_enabled by config.attn_implementation * fix typo * add tests * style * add a copy.deepcopy on the config in from_pretrained, as we do not want to modify it inplace * obey to config.attn_implementation if a config is passed in from_pretrained * fix is_torch_sdpa_available when torch is not installed * remove dead code * Update src/transformers/modeling_attn_mask_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/modeling_attn_mask_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/modeling_attn_mask_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/modeling_attn_mask_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/modeling_attn_mask_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/models/bart/modeling_bart.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * remove duplicate pretraining_tp code * add dropout in llama * precise comment on attn_mask * add fmt: off for _unmask_unattended docstring * precise num_masks comment * nuke pretraining_tp in LlamaSDPAAttention following Arthur's suggestion * cleanup modeling_utils * backward compatibility * fix style as requested * style * improve documentation * test pass * style * add _unmask_unattended tests * skip meaningless tests for idefics * hard_check SDPA requirements when specifically requested * standardize the use if XXX_ATTENTION_CLASSES * fix SDPA bug with mem-efficient backend on CUDA when using fp32 * fix test * rely on SDPA is_causal parameter to handle the causal mask in some cases * fix FALCON_ATTENTION_CLASSES * remove _flash_attn_2_enabled occurences * fix test * add OPT to the list of supported flash models * improve test * properly test on different SDPA backends, on different dtypes & properly handle separately the pad tokens in the test * remove remaining _flash_attn_2_enabled occurence * Update src/transformers/modeling_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/modeling_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/modeling_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/modeling_attn_mask_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update docs/source/en/perf_infer_gpu_one.md Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * remove use_attn_implementation * fix docstring & slight bug * make attn_implementation internal (_attn_implementation) * typos * fix tests * deprecate 
use_flash_attention_2=True * fix test * add back llama that was removed by mistake * fix tests * remove _flash_attn_2_enabled occurences bis * add check & test that passed attn_implementation is valid * fix falcon torchscript export * fix device of mask in tests * add tip about torch.jit.trace and move bt doc below sdpa * fix parameterized.expand order * move tests from test_modeling_attn_mask_utils to test_modeling_utils as a relevant test class is already there * update sdpaattention class with the new cache * Update src/transformers/configuration_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/models/bark/modeling_bark.py * address review comments * WIP torch.jit.trace fix. left: test both eager & sdpa * add test for torch.jit.trace for both eager/sdpa * fix falcon with torch==2.0 that needs to use sdpa * fix doc * hopefully last fix * fix key_value_length that has no default now in mask converter * is it flacky? * fix speculative decoding bug * tests do pass * fix following #27907 --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
parent
ce0bbd5101
commit
80377eb018
@ -441,7 +441,7 @@ flush()
```

For comparison, let's run the same function, but enable Flash Attention instead.
To do so, we convert the model to [BetterTransformers](https://huggingface.co/docs/optimum/bettertransformer/overview) and by doing so enabling PyTorch's [SDPA self-attention](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention) which in turn is based on Flash Attention.
To do so, we convert the model to [BetterTransformer](https://huggingface.co/docs/optimum/bettertransformer/overview) and by doing so enabling PyTorch's [SDPA self-attention](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention) which in turn is able to use Flash Attention.

```python
model.to_bettertransformer()
@ -83,10 +83,10 @@ pip install -U flash-attn --no-build-isolation

##### Usage

To load a model using Flash Attention 2, we can pass the `use_flash_attention_2` flag to [`.from_pretrained`](https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained). We'll also load the model in half-precision (e.g. `torch.float16`), since it results in almost no degradation to audio quality but significantly lower memory usage and faster inference:
To load a model using Flash Attention 2, we can pass the `attn_implementation="flash_attention_2"` flag to [`.from_pretrained`](https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained). We'll also load the model in half-precision (e.g. `torch.float16`), since it results in almost no degradation to audio quality but significantly lower memory usage and faster inference:

```python
model = BarkModel.from_pretrained("suno/bark-small", torch_dtype=torch.float16, use_flash_attention_2=True).to(device)
model = BarkModel.from_pretrained("suno/bark-small", torch_dtype=torch.float16, attn_implementation="flash_attention_2").to(device)
```

##### Performance comparison
@ -114,7 +114,7 @@ import torch
device = "cuda" if torch.cuda.is_available() else "cpu"

# load in fp16 and use Flash Attention 2
model = BarkModel.from_pretrained("suno/bark-small", torch_dtype=torch.float16, use_flash_attention_2=True).to(device)
model = BarkModel.from_pretrained("suno/bark-small", torch_dtype=torch.float16, attn_implementation="flash_attention_2").to(device)

# enable CPU offload
model.enable_cpu_offload()
@ -153,7 +153,7 @@ To load and run a model using Flash Attention 2, refer to the snippet below:
>>> device = "cuda" # the device to load the model onto

>>> tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
>>> model = AutoModel.from_pretrained("distilbert-base-uncased", torch_dtype=torch.float16, use_flash_attention_2=True)
>>> model = AutoModel.from_pretrained("distilbert-base-uncased", torch_dtype=torch.float16, attn_implementation="flash_attention_2")

>>> text = "Replace me by any text you'd like."
@ -59,7 +59,7 @@ To load and run a model using Flash Attention 2, refer to the snippet below:
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> device = "cuda" # the device to load the model onto

>>> model = AutoModelForCausalLM.from_pretrained("bigcode/gpt_bigcode-santacoder", torch_dtype=torch.float16, use_flash_attention_2=True)
>>> model = AutoModelForCausalLM.from_pretrained("bigcode/gpt_bigcode-santacoder", torch_dtype=torch.float16, attn_implementation="flash_attention_2")
>>> tokenizer = AutoTokenizer.from_pretrained("bigcode/gpt_bigcode-santacoder")

>>> prompt = "def hello_world():"
@ -67,7 +67,7 @@ To load and run a model using Flash Attention 2, refer to the snippet below:
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> device = "cuda" # the device to load the model onto

>>> model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-2.7B", torch_dtype=torch.float16, use_flash_attention_2=True)
>>> model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-2.7B", torch_dtype=torch.float16, attn_implementation="flash_attention_2")
>>> tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-2.7B")

>>> prompt = "def hello_world():"
@ -77,12 +77,12 @@ pip install -U flash-attn --no-build-isolation

### Usage

To load a model using Flash Attention 2, we can pass the `use_flash_attention_2` flag to [`.from_pretrained`](https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained). We'll also load the model in half-precision (e.g. `torch.float16`), since it results in almost no degradation to audio quality but significantly lower memory usage and faster inference:
To load a model using Flash Attention 2, we can pass the argument `attn_implementation="flash_attention_2"` to [`.from_pretrained`](https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained). We'll also load the model in half-precision (e.g. `torch.float16`), since it results in almost no degradation to audio quality but significantly lower memory usage and faster inference:

```python
>>> from transformers import GPTNeoXForCausalLM, GPTNeoXTokenizerFast

model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/gpt-neox-20b", torch_dtype=torch.float16, use_flash_attention_2=True).to(device)
model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/gpt-neox-20b", torch_dtype=torch.float16, attn_implementation="flash_attention_2").to(device)
...
```
@ -99,7 +99,7 @@ To load and run a model using Flash Attention 2, refer to the snippet below:
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> device = "cuda" # the device to load the model onto

>>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", torch_dtype=torch.float16, use_flash_attention_2=True)
>>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", torch_dtype=torch.float16, attn_implementation="flash_attention_2")
>>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")

>>> prompt = "My favourite condiment is"
@ -80,7 +80,7 @@ To load and run a model using Flash Attention 2, refer to the snippet below:
>>> from transformers import OPTForCausalLM, GPT2Tokenizer
>>> device = "cuda" # the device to load the model onto

>>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.float16, use_flash_attention_2=True)
>>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.float16, attn_implementation="flash_attention_2")
>>> tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m")

>>> prompt = ("A chat between a curious human and the Statue of Liberty.\n\nHuman: What is your name?\nStatue: I am the "
@ -111,7 +111,7 @@ To load and run a model using Flash Attention 2, refer to the snippet below:
>>> from transformers import PhiForCausalLM, AutoTokenizer

>>> # define the model and tokenizer and push the model and tokens to the GPU.
>>> model = PhiForCausalLM.from_pretrained("susnato/phi-1_5_dev", torch_dtype=torch.float16, use_flash_attention_2=True).to("cuda")
>>> model = PhiForCausalLM.from_pretrained("susnato/phi-1_5_dev", torch_dtype=torch.float16, attn_implementation="flash_attention_2").to("cuda")
>>> tokenizer = AutoTokenizer.from_pretrained("susnato/phi-1_5_dev")

>>> # feel free to change the prompt to your liking.
@ -163,4 +163,4 @@ Below is an expected speedup diagram that compares pure inference time between t
- forward

</pt>
</frameworkcontent>
</frameworkcontent>
@ -36,13 +36,29 @@ FlashAttention-2 is experimental and may change considerably in future versions.
1. additionally parallelizing the attention computation over sequence length
2. partitioning the work between GPU threads to reduce communication and shared memory reads/writes between them

FlashAttention-2 supports inference with Llama, Mistral, Falcon and Bark models. You can request to add FlashAttention-2 support for another model by opening a GitHub Issue or Pull Request.
FlashAttention-2 is currently supported for the following architectures:
* [Bark](https://huggingface.co/docs/transformers/model_doc/bark#transformers.BarkModel)
* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
* [DistilBert](https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel)
* [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel)
* [GPTNeo](https://huggingface.co/docs/transformers/model_doc/gpt_neo#transformers.GPTNeoModel)
* [GPTNeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox#transformers.GPTNeoXModel)
* [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
* [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel)
* [Llava](https://huggingface.co/docs/transformers/model_doc/llava)
* [MBart](https://huggingface.co/docs/transformers/model_doc/mbart#transformers.MBartModel)
* [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel)
* [OPT](https://huggingface.co/docs/transformers/model_doc/opt#transformers.OPTModel)
* [Phi](https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel)
* [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel)

You can request to add FlashAttention-2 support for another model by opening a GitHub Issue or Pull Request.
Before you begin, make sure you have FlashAttention-2 installed. For NVIDIA GPUs, the library is installable through pip: `pip install flash-attn --no-build-isolation`. We strongly suggest referring to the [detailed installation instructions](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#installation-and-features).

FlashAttention-2 is also supported on AMD GPUs, with the current support limited to **Instinct MI210 and Instinct MI250**. We strongly suggest using the following [Dockerfile](https://github.com/huggingface/optimum-amd/tree/main/docker/transformers-pytorch-amd-gpu-flash/Dockerfile) to use FlashAttention-2 on AMD GPUs.
To enable FlashAttention-2, add the `use_flash_attention_2` parameter to [`~AutoModelForCausalLM.from_pretrained`]:
To enable FlashAttention-2, pass the argument `attn_implementation="flash_attention_2"` to [`~AutoModelForCausalLM.from_pretrained`]:

```python
import torch
@ -54,13 +70,15 @@ tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    use_flash_attention_2=True,
    attn_implementation="flash_attention_2",
)
```
<Tip>

FlashAttention-2 can only be used when the model's dtype is `fp16` or `bf16`. Make sure to cast your model to the appropriate dtype and load it on a supported device before using FlashAttention-2.

Note that `use_flash_attention_2=True` can also be used to enable Flash Attention 2, but is deprecated in favor of `attn_implementation="flash_attention_2"`.

</Tip>
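For reference, a minimal sketch of the two equivalent ways of requesting Flash Attention 2 at load time (the checkpoint name is only a placeholder, and `flash_attn` plus a GPU are assumed available; the deprecated flag simply logs a warning and maps onto the same `config._attn_implementation` value):

```python
import torch
from transformers import AutoModelForCausalLM

# preferred API introduced by this PR
model = AutoModelForCausalLM.from_pretrained(
    "tiiuae/falcon-7b",  # placeholder checkpoint; any FlashAttention-2-enabled model works
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
)

# deprecated flag, kept for backward compatibility: emits a warning and sets
# config._attn_implementation = "flash_attention_2" internally
model = AutoModelForCausalLM.from_pretrained(
    "tiiuae/falcon-7b",
    torch_dtype=torch.float16,
    use_flash_attention_2=True,
)
```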
@ -77,14 +95,14 @@ tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    load_in_8bit=True,
    use_flash_attention_2=True,
    attn_implementation="flash_attention_2",
)

# load in 4bit
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    load_in_4bit=True,
    use_flash_attention_2=True,
    attn_implementation="flash_attention_2",
)
```
@ -124,8 +142,58 @@ FlashAttention is more memory efficient, meaning you can train on much larger se
<img src="https://huggingface.co/datasets/ybelkada/documentation-images/resolve/main/llama-2-large-seqlen-padding.png">
</div>

## FlashAttention and memory-efficient attention through PyTorch's scaled_dot_product_attention

PyTorch's [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) (SDPA) can also call FlashAttention and memory-efficient attention kernels under the hood. SDPA support is currently being added natively in Transformers, and is used by default for `torch>=2.1.1` when an implementation is available.

For now, Transformers supports inference and training through SDPA for the following architectures:
* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
* [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel)
* [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
* [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel)
* [Idefics](https://huggingface.co/docs/transformers/model_doc/idefics#transformers.IdeficsModel)
* [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel)

Note that FlashAttention can only be used for models with the `fp16` or `bf16` torch type, so make sure to cast your model to the appropriate type before using it.
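SDPA is picked automatically when available, but it can also be requested explicitly through the same `attn_implementation` argument; a minimal sketch (the checkpoint is a placeholder and `torch>=2.1.1` is assumed):

```python
import torch
from transformers import AutoModelForCausalLM

# explicitly request the SDPA-based attention classes added in this PR;
# _check_and_enable_sdpa raises if the architecture does not support it
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # placeholder; any SDPA-supported architecture
    torch_dtype=torch.float16,
    attn_implementation="sdpa",
).to("cuda")
```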
By default, `torch.nn.functional.scaled_dot_product_attention` selects the most performant kernel available, but to check whether a backend is available in a given setting (hardware, problem size), you can use [`torch.backends.cuda.sdp_kernel`](https://pytorch.org/docs/master/backends.html#torch.backends.cuda.sdp_kernel) as a context manager:

```diff
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.float16).to("cuda")
# convert the model to BetterTransformer
model.to_bettertransformer()

input_text = "Hello my dog is cute and"
inputs = tokenizer(input_text, return_tensors="pt").to("cuda")

+ with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    outputs = model.generate(**inputs)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
If you see a bug with the traceback below, try using nightly version of PyTorch which may have broader coverage for FlashAttention:

```bash
RuntimeError: No available kernel. Aborting execution.

# install PyTorch nightly
pip3 install -U --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu118
```

## BetterTransformer
<Tip warning={true}>

Some of the BetterTransformer features are being upstreamed in Transformers, with native `torch.nn.functional.scaled_dot_product_attention` default support. BetterTransformer still has a wider coverage than the Transformers SDPA integration, but you can expect more and more architectures to natively support SDPA in Transformers.

</Tip>


<Tip>

Check out our benchmarks with BetterTransformer and scaled dot product attention in the [Out of the box acceleration and memory savings of 🤗 decoder models with PyTorch 2.0](https://pytorch.org/blog/out-of-the-box-acceleration/) and learn more about the fastpath execution in the [BetterTransformer](https://medium.com/pytorch/bettertransformer-out-of-the-box-performance-for-huggingface-transformers-3fbe27d50ab2) blog post.
@ -154,39 +222,6 @@ model = model.reverse_bettertransformer()
model.save_pretrained("saved_model")
```

### FlashAttention

SDPA can also call FlashAttention kernels under the hood. FlashAttention can only be used for models using the `fp16` or `bf16` dtype, so make sure to cast your model to the appropriate dtype before using it.

To enable FlashAttention or to check whether it is available in a given setting (hardware, problem size), use [`torch.backends.cuda.sdp_kernel`](https://pytorch.org/docs/master/backends.html#torch.backends.cuda.sdp_kernel) as a context manager:

```diff
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.float16).to("cuda")
# convert the model to BetterTransformer
model.to_bettertransformer()

input_text = "Hello my dog is cute and"
inputs = tokenizer(input_text, return_tensors="pt").to("cuda")

+ with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    outputs = model.generate(**inputs)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

If you see a bug with the traceback below, try using nightly version of PyTorch which may have broader coverage for FlashAttention:

```bash
RuntimeError: No available kernel. Aborting execution.

# install PyTorch nightly
pip3 install -U --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu118
```

## bitsandbytes

bitsandbytes is a quantization library that includes support for 4-bit and 8-bit quantization. Quantization reduces your model size compared to its native full precision version, making it easier to fit large models onto GPUs with limited memory.
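As a quick illustration of what this section introduces (a hedged sketch; it assumes the `bitsandbytes` package is installed and reuses a placeholder checkpoint from the snippets above):

```python
from transformers import AutoModelForCausalLM

# load the weights quantized to 4-bit so the model fits in less GPU memory
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",  # placeholder checkpoint
    load_in_4bit=True,
    device_map="auto",
)
```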
@ -82,7 +82,7 @@ AWQ quantization can also be combined with [FlashAttention-2](perf_infer_gpu_one
```py
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("TheBloke/zephyr-7B-alpha-AWQ", use_flash_attention_2=True, device_map="cuda:0")
model = AutoModelForCausalLM.from_pretrained("TheBloke/zephyr-7B-alpha-AWQ", attn_implementation="flash_attention_2", device_map="cuda:0")
```
@ -44,7 +44,7 @@ Flash Attention 2 can only be used when the model's dtype is `fp16` or `bf16`

### Quick usage

To enable Flash Attention 2 for a model, add `use_flash_attention_2` to the arguments of `from_pretrained`.
To enable Flash Attention 2 for a model, add `attn_implementation="flash_attention_2"` to the arguments of `from_pretrained`.


```python
@ -57,7 +57,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    use_flash_attention_2=True,
    attn_implementation="flash_attention_2",
)
```

@ -114,7 +114,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    load_in_8bit=True,
    use_flash_attention_2=True,
    attn_implementation="flash_attention_2",
)
```

@ -132,7 +132,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    load_in_4bit=True,
    use_flash_attention_2=True,
    attn_implementation="flash_attention_2",
)
```

@ -151,7 +151,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    load_in_4bit=True,
    use_flash_attention_2=True,
    attn_implementation="flash_attention_2",
)

lora_config = LoraConfig(
@ -66,12 +66,12 @@ model = AutoModelForCausalLM.from_pretrained(model_id).to("cuda:0")

### Combining AWQ and Flash Attention

You can combine AWQ quantization with Flash Attention to get a model that is both quantized and faster. Simply load the model with `from_pretrained` and pass the `use_flash_attention_2=True` argument.
You can combine AWQ quantization with Flash Attention to get a model that is both quantized and faster. Simply load the model with `from_pretrained` and pass the `attn_implementation="flash_attention_2"` argument.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("TheBloke/zephyr-7B-alpha-AWQ", use_flash_attention_2=True, device_map="cuda:0")
model = AutoModelForCausalLM.from_pretrained("TheBloke/zephyr-7B-alpha-AWQ", attn_implementation="flash_attention_2", device_map="cuda:0")
```

### Benchmarks
@ -236,6 +236,8 @@ class PretrainedConfig(PushToHubMixin):

            This attribute is currently not being used during model loading time, but this may change in the future
            versions. But we can already start preparing for the future by saving the dtype with save_pretrained.
        attn_implementation (`str`, *optional*):
            The attention implementation to use in the model. Can be any of `"eager"` (manual implementation of the attention), `"sdpa"` (attention using [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)), or `"flash_attention_2"` (attention using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation.

    > TensorFlow specific parameters

@ -374,6 +376,9 @@ class PretrainedConfig(PushToHubMixin):
        # Config hash
        self._commit_hash = kwargs.pop("_commit_hash", None)

        # Attention implementation to use, if relevant.
        self._attn_implementation_internal = kwargs.pop("attn_implementation", None)

        # Drop the transformers version info
        self.transformers_version = kwargs.pop("transformers_version", None)

@ -422,6 +427,22 @@ class PretrainedConfig(PushToHubMixin):
            self.id2label = {i: f"LABEL_{i}" for i in range(num_labels)}
            self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))

    @property
    def _attn_implementation(self):
        # This property is made private for now (as it cannot be changed and a PreTrainedModel.use_attn_implementation method needs to be implemented.)
        if hasattr(self, "_attn_implementation_internal"):
            if self._attn_implementation_internal is None:
                # `config.attn_implementation` should never be None, for backward compatibility.
                return "eager"
            else:
                return self._attn_implementation_internal
        else:
            return "eager"

    @_attn_implementation.setter
    def _attn_implementation(self, value):
        self._attn_implementation_internal = value

    def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
        """
        Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the

@ -747,6 +768,9 @@ class PretrainedConfig(PushToHubMixin):
        if "_commit_hash" in kwargs and "_commit_hash" in config_dict:
            kwargs["_commit_hash"] = config_dict["_commit_hash"]

        # We remove it from kwargs so that it does not appear in `return_unused_kwargs`.
        config_dict["attn_implementation"] = kwargs.pop("attn_implementation", None)

        config = cls(**config_dict)

        if hasattr(config, "pruned_heads"):

@ -861,8 +885,8 @@ class PretrainedConfig(PushToHubMixin):

        self.dict_torch_dtype_to_str(serializable_config_dict)

        if "_flash_attn_2_enabled" in serializable_config_dict:
            del serializable_config_dict["_flash_attn_2_enabled"]
        if "_attn_implementation_internal" in serializable_config_dict:
            del serializable_config_dict["_attn_implementation_internal"]

        return serializable_config_dict

@ -880,8 +904,8 @@ class PretrainedConfig(PushToHubMixin):
            del output["_auto_class"]
        if "_commit_hash" in output:
            del output["_commit_hash"]
        if "_flash_attn_2_enabled" in output:
            del output["_flash_attn_2_enabled"]
        if "_attn_implementation_internal" in output:
            del output["_attn_implementation_internal"]

        # Transformers version when serializing the model
        output["transformers_version"] = __version__
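As a sketch of how the new kwarg travels through `PretrainedConfig` (the checkpoint is a placeholder; the private attribute names are the ones added in this diff):

```python
from transformers import AutoConfig

# the kwarg is popped in `from_dict` and stored on the private attribute
config = AutoConfig.from_pretrained("gpt2", attn_implementation="eager")
print(config._attn_implementation)  # "eager" (property defined above)

# the internal attribute is stripped when serializing, so it never lands in config.json
print("_attn_implementation_internal" in config.to_dict())  # False
```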
@ -68,7 +68,7 @@ class AttentionMaskConverter:
        key_value_length: int,
        dtype: torch.dtype,
        device: Union[torch.device, "str"] = "cpu",
    ) -> torch.Tensor:
    ) -> Optional[torch.Tensor]:
        """
        Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative
        bias to upper right hand triangular matrix (causal mask).

@ -184,6 +184,95 @@ class AttentionMaskConverter:

        return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)

    @staticmethod
    def _unmask_unattended(
        expanded_mask: torch.Tensor, attention_mask: torch.Tensor, unmasked_value: Union[bool, float]
    ):
        # fmt: off
        """
        Attend to all tokens in masked rows from the expanded attention mask, for example the relevant first rows when
        using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
        Details: https://github.com/pytorch/pytorch/issues/110213

        `expanded_mask` is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len].
        `attention_mask` is [bsz, src_seq_len].

        The dimension num_masks of `expanded_mask` is most often 1, but it can also be the number of heads in the case of alibi attention bias.

        For example, if `attention_mask` is
        ```
        [[0, 0, 1],
         [1, 1, 1],
         [0, 1, 1]]
        ```
        and `expanded_mask` is (e.g. here left-padding case)
        ```
        [[[[0, 0, 0],
           [0, 0, 0],
           [0, 0, 1]]],
         [[[1, 0, 0],
           [1, 1, 0],
           [1, 1, 1]]],
         [[[0, 0, 0],
           [0, 1, 0],
           [0, 1, 1]]]]
        ```
        then the modified `expanded_mask` will be
        ```
        [[[[1, 1, 1], <-- modified
           [1, 1, 1], <-- modified
           [0, 0, 1]]],
         [[[1, 0, 0],
           [1, 1, 0],
           [1, 1, 1]]],
         [[[1, 1, 1], <-- modified
           [0, 1, 0],
           [0, 1, 1]]]]
        ```
        """
        # fmt: on

        # Get the index of the first non-zero value for every sample in the batch.
        # In the above example, indices = [[2], [0], [1]]]
        tmp = torch.arange(attention_mask.shape[1], 0, -1)
        indices = torch.argmax(attention_mask.cpu() * tmp, 1, keepdim=True)

        # Find the batch indexes that have unattended tokens on the leftmost side (e.g. [0, 0, 1, 1, 1]), for which the first rows of the
        # expanded mask will be completely unattended.
        left_masked_rows = torch.where(indices > 0)[0]

        if left_masked_rows.shape[0] == 0:
            return expanded_mask
        indices = indices[left_masked_rows]

        max_len = torch.max(indices)
        range_tensor = torch.arange(max_len).unsqueeze(0)
        range_tensor = range_tensor.repeat(indices.size(0), 1)

        # Avoid unmasking tokens at relevant target positions (on the row axis), by rather unmasking possibly several times the first row that should always be unmasked as we filtered out the batch above.
        range_tensor[range_tensor >= indices] = 0

        # TODO: we may drop support for 3D attention mask as the refactor from Patrick maybe dropped this case
        if expanded_mask.dim() == 4:
            num_masks = expanded_mask.shape[1]
            if num_masks == 1:
                # Broadcast [left_masked_rows, 1], [left_masked_rows, max_len]
                mask_slice = (left_masked_rows[:, None], 0, range_tensor)
            else:
                # Broadcast [left_masked_rows, 1, 1], [1, num_masks, 1], [left_masked_rows, 1, max_len]
                mask_slice = (
                    left_masked_rows[:, None, None],
                    torch.arange(num_masks)[None, :, None],
                    range_tensor[:, None, :],
                )
        else:
            # Broadcast [left_masked_rows, 1], [left_masked_rows, max_len]
            mask_slice = (left_masked_rows[:, None], range_tensor)

        expanded_mask[mask_slice] = unmasked_value

        return expanded_mask


def _prepare_4d_causal_attention_mask(
    attention_mask: Optional[torch.Tensor],
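To make the behaviour of the new helper concrete, here is a small self-contained sketch using the exact 3-sample example from the docstring above (values taken verbatim from the docstring; the helper is a static method of `AttentionMaskConverter`):

```python
import torch
from transformers.modeling_attn_mask_utils import AttentionMaskConverter

# 2D padding mask from the docstring: samples 0 and 2 are left-padded
attention_mask = torch.tensor([[0, 0, 1], [1, 1, 1], [0, 1, 1]])

# matching 4D mask of shape (bsz, 1, tgt_len, src_len)
expanded_mask = torch.tensor(
    [[[[0, 0, 0], [0, 0, 0], [0, 0, 1]]],
     [[[1, 0, 0], [1, 1, 0], [1, 1, 1]]],
     [[[0, 0, 0], [0, 1, 0], [0, 1, 1]]]]
)

# rows that attend to no token at all are fully unmasked, which is what the
# SDPA memory-efficient backend needs to avoid NaNs (pytorch#110213)
fixed = AttentionMaskConverter._unmask_unattended(expanded_mask, attention_mask, unmasked_value=1)
print(fixed[0, 0])  # the first two rows are now [1, 1, 1], as in the docstring
```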
@ -225,6 +314,78 @@ def _prepare_4d_causal_attention_mask(
    return attention_mask


# Adapted from _prepare_4d_causal_attention_mask
def _prepare_4d_causal_attention_mask_for_sdpa(
    attention_mask: Optional[torch.Tensor],
    input_shape: Union[torch.Size, Tuple, List],
    inputs_embeds: torch.Tensor,
    past_key_values_length: int,
    sliding_window: Optional[int] = None,
):
    """
    Prepares the correct `attn_mask` argument to be used by `torch.nn.functional.scaled_dot_product_attention`.

    In case no token is masked in the `attention_mask` argument, we simply set it to `None` for the cases `query_length == 1` and
    `key_value_length == query_length`, and rely instead on SDPA `is_causal` argument to use causal/non-causal masks,
    allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed).
    """
    attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)

    key_value_length = input_shape[-1] + past_key_values_length
    batch_size, query_length = input_shape

    # torch.jit.trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
    # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
    # TODO: Fix this as well when using torchdynamo with fullgraph=True.
    is_tracing = torch.jit.is_tracing()

    if attention_mask is not None:
        if torch.all(attention_mask == 1):
            if is_tracing:
                pass
            elif query_length == 1:
                # For query_length == 1, causal attention and bi-directional attention are the same.
                attention_mask = None
            elif key_value_length == query_length:
                attention_mask = None
            else:
                # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation
                # may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
                # Reference: https://github.com/pytorch/pytorch/issues/108108
                pass
    elif query_length > 1 and key_value_length != query_length:
        # See the comment above (https://github.com/pytorch/pytorch/issues/108108).
        # Ugly: we set it to True here to dispatch in the following controlflow to `to_causal_4d`.
        attention_mask = True
    elif is_tracing:
        raise ValueError(
            'Attention using SDPA can not be traced with torch.jit.trace when no attention_mask is provided. To solve this issue, please either load your model with the argument `attn_implementation="eager"` or pass an attention_mask input when tracing the model.'
        )

    if attention_mask is None:
        expanded_4d_mask = None
    elif attention_mask is True:
        expanded_4d_mask = attn_mask_converter.to_causal_4d(
            input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
        )
    else:
        expanded_4d_mask = attn_mask_converter.to_4d(
            attention_mask,
            input_shape[-1],
            dtype=inputs_embeds.dtype,
            key_value_length=key_value_length,
        )

        # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
        # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
        if query_length > 1:
            expanded_4d_mask = AttentionMaskConverter._unmask_unattended(
                expanded_4d_mask, attention_mask, unmasked_value=0.0
            )

    return expanded_4d_mask


def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """
    Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
@ -241,13 +402,51 @@ def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len:
    return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)


def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """
    Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
    `(batch_size, key_value_length)`

    Args:
        mask (`torch.Tensor` or `None`):
            A 2D attention mask of shape `(batch_size, key_value_length)`
        dtype (`torch.dtype`):
            The torch dtype the created mask shall have.
        tgt_len (`int`):
            The target length or query length the created mask shall have.
    """
    batch_size, key_value_length = mask.shape
    tgt_len = tgt_len if tgt_len is not None else key_value_length

    # torch.jit.trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
    # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
    # TODO: Fix this as well when using torchdynamo with fullgraph=True.
    is_tracing = torch.jit.is_tracing()

    if torch.all(mask == 1):
        if is_tracing:
            pass
        elif tgt_len == 1:
            # For query_length == 1, causal attention and bi-directional attention are the same.
            return None
        elif key_value_length == tgt_len:
            return None
        else:
            # Unfortunately, for query_length > 1 and key_value_length != query_length, we can not generally ignore the attention mask, as SDPA causal mask generation
            # may be wrong. We will set is_causal=False in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
            # Reference: https://github.com/pytorch/pytorch/issues/108108
            return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
    else:
        return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)


def _create_4d_causal_attention_mask(
    input_shape: Union[torch.Size, Tuple, List],
    dtype: torch.dtype,
    device: torch.device,
    past_key_values_length: int = 0,
    sliding_window: Optional[int] = None,
):
) -> Optional[torch.Tensor]:
    """
    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)`
@ -81,6 +81,7 @@ from .utils import (
    is_peft_available,
    is_remote_url,
    is_safetensors_available,
    is_torch_sdpa_available,
    is_torch_tpu_available,
    logging,
    replace_return_docstrings,

@ -1128,6 +1129,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
    # Flash Attention 2 support
    _supports_flash_attn_2 = False

    # SDPA support
    _supports_sdpa = False

    # Has support for a `Cache` instance as `past_key_values`
    _supports_cache_class = False

@ -1154,7 +1158,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
            )
        # Save config and origin of the pretrained weights if given in model
        config = self._autoset_attn_implementation(
            config, torch_dtype=torch.get_default_dtype(), check_device_map=False
        )
        self.config = config

        self.name_or_path = config.name_or_path
        self.warnings_issued = {}
        self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None

@ -1185,8 +1193,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
        Args:
            torch_dtype (`torch.dtype`, *optional*):
                Override the default `torch.dtype` and load the model under this dtype.
            use_flash_attention_2 (`bool`, *optional*):
                Whether to load the model with Flash Attention 2 modules.
        """
        torch_dtype = kwargs.pop("torch_dtype", None)
        use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False)

@ -1196,8 +1202,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
        if torch_dtype is not None:
            dtype_orig = cls._set_default_torch_dtype(torch_dtype)

        if use_flash_attention_2:
            config = cls._check_and_enable_flash_attn_2(config, torch_dtype)
        config = copy.deepcopy(config)  # We do not want to modify the config inplace in _from_config.
        config._attn_implementation = kwargs.pop("attn_implementation", None)
        config = cls._autoset_attn_implementation(
            config, use_flash_attention_2=use_flash_attention_2, check_device_map=False
        )

        if is_deepspeed_zero3_enabled():
            import deepspeed

@ -1216,6 +1225,67 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix

        return model
    @classmethod
    def _autoset_attn_implementation(
        cls,
        config,
        use_flash_attention_2: bool = False,
        torch_dtype: Optional[torch.dtype] = None,
        device_map: Optional[Union[str, Dict[str, int]]] = None,
        check_device_map: bool = True,
    ):
        """
        Automatically checks and dispatches to a default attention implementation. In order of priority:
            1. An implementation specified in `config._attn_implementation` (due for example to the argument attn_implementation="sdpa" in from_pretrained).
            2. DEPRECATED: if use_flash_attention_2 is set to `True` and `flash_attn` is available, flash attention. (`LlamaFlashAttention` for example)
            3. SDPA implementation, if available and supported by the model type. (`LlamaSdpaAttention` for example)
            4. The default model's implementation otherwise (`LlamaAttention` for example).
        """
        # Here we use config._attn_implementation_internal to check whether the attention implementation was explicitly set by the user.
        # The property `PretrainedConfig._attn_implementation` is never `None`, for backward compatibility (always fall back on "eager").
        # The `hasattr` here is used as some Transformers tests for some reason do not call PretrainedConfig __init__ (e.g. test_no_super_init_config_and_model)
        if hasattr(config, "_attn_implementation_internal") and config._attn_implementation_internal is not None:
            if config._attn_implementation != "flash_attention_2" and use_flash_attention_2:
                raise ValueError(
                    f'Both attn_implementation="{config._attn_implementation}" and `use_flash_attention_2=True` were used when loading the model, which are not compatible.'
                    ' We recommend to just use `attn_implementation="flash_attention_2"` when loading the model.'
                )

            if config._attn_implementation not in ["eager", "sdpa", "flash_attention_2"]:
                message = f'Specified `attn_implementation="{config._attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)'
                if cls._supports_flash_attn_2:
                    message += ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)'
                if cls._supports_sdpa:
                    message += ', `"attn_implementation=sdpa"` (implementation using torch.nn.functional.scaled_dot_product_attention)'
                raise ValueError(message + ".")

            # If a config is passed with a preset attn_implementation, we skip the automatic dispatch and use the user-provided config, with hard checks that the requested attention implementation is available.
            hard_check_only = True
        else:
            hard_check_only = False

        if use_flash_attention_2:
            logger.warning_once(
                'The model was loaded with use_flash_attention_2=True, which is deprecated and may be removed in a future release. Please use `attn_implementation="flash_attention_2"` instead.'
            )
            config._attn_implementation = "flash_attention_2"

        if config._attn_implementation == "flash_attention_2":
            cls._check_and_enable_flash_attn_2(
                config,
                torch_dtype=torch_dtype,
                device_map=device_map,
                hard_check_only=hard_check_only,
                check_device_map=check_device_map,
            )
        elif cls._supports_sdpa or config._attn_implementation == "sdpa":
            # use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif.
            config = cls._check_and_enable_sdpa(config, hard_check_only=hard_check_only)
        elif not hard_check_only:
            config._attn_implementation = "eager"

        return config

    @classmethod
    def _set_default_torch_dtype(cls, dtype: torch.dtype) -> torch.dtype:
        """
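A short sketch of the dispatch described in the docstring above (the checkpoint name is a placeholder): an explicit request is honoured and hard-checked, while an unknown value triggers the `ValueError` constructed in this method:

```python
from transformers import AutoModelForCausalLM

# explicit request: honoured, with a hard check that the implementation is usable
model = AutoModelForCausalLM.from_pretrained("gpt2", attn_implementation="eager")
print(model.config._attn_implementation)  # "eager"

# unknown value: raises the ValueError listing the supported implementations
try:
    AutoModelForCausalLM.from_pretrained("gpt2", attn_implementation="flash_attention_3")
except ValueError as err:
    print(err)
```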
@ -1266,38 +1336,33 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix

    @classmethod
    def _check_and_enable_flash_attn_2(
        cls, config, torch_dtype: Optional[torch.dtype] = None, device_map: Optional[Union[str, Dict[str, int]]] = None
        cls,
        config,
        torch_dtype: Optional[torch.dtype] = None,
        device_map: Optional[Union[str, Dict[str, int]]] = None,
        check_device_map: bool = True,
        hard_check_only: bool = False,
    ) -> PretrainedConfig:
        """
        If you don't know about Flash Attention, check out the official repository of flash attention:
        https://github.com/Dao-AILab/flash-attention
        Checks the availability of Flash Attention 2 and compatibility with the current model.

        For using Flash Attention 1.0 you can do it directly via the `BetterTransformer` API, have a look at this
        specific section of the documentation to learn more about it:
        https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#decoder-models

        The method checks if the current setup is compatible with Flash Attention as it requires the model to be in
        half precision and not ran on CPU.

        If all checks pass, the method will create an attribute in the config `_flash_attn_2_enabled` so that the model
        can initialize the correct attention module
        If all checks pass and `hard_check_only` is False, the method will set the config attribute `attn_implementation` to "flash_attention_2" so that the model can initialize the correct attention module.
        """
        if not cls._supports_flash_attn_2:
            raise ValueError(
                "The current architecture does not support Flash Attention 2.0. Please open an issue on GitHub to "
                f"{cls.__name__} does not support Flash Attention 2.0 yet. Please open an issue on GitHub to "
                "request support for this architecture: https://github.com/huggingface/transformers/issues/new"
            )

        if not is_flash_attn_2_available():
            flash_attention_version = version.parse(importlib.metadata.version("flash_attn"))

            preface = "FlashAttention2 has been toggled on, but it cannot be used due to the following error:"
            install_message = "Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2."
            if torch.version.cuda:
                if importlib.util.find_spec("flash_attn") is None:
                    raise ImportError(f"{preface} the package flash_attn seems to be not installed. {install_message}")

                flash_attention_version = version.parse(importlib.metadata.version("flash_attn"))
            if importlib.util.find_spec("flash_attn") is None:
                raise ImportError(f"{preface} the package flash_attn seems to be not installed. {install_message}")

            flash_attention_version = version.parse(importlib.metadata.version("flash_attn"))
            if torch.version.cuda:
                if flash_attention_version < version.parse("2.1.0"):
                    raise ImportError(
                        f"{preface} you need flash_attn package version to be greater or equal than 2.1.0. Detected version {flash_attention_version}. {install_message}"

@ -1305,9 +1370,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                else:
                    raise ImportError(f"{preface} Flash Attention 2 is not available. {install_message}")
            elif torch.version.hip:
                if importlib.util.find_spec("flash_attn") is None:
                    raise ImportError(f"{preface} the package flash_attn seems to be not installed. {install_message}")
                flash_attention_version = version.parse(importlib.metadata.version("flash_attn"))
                if flash_attention_version < version.parse("2.0.4"):
                    raise ImportError(
                        f"{preface} you need flash_attn package version to be greater or equal than 2.0.4. Make sure to have that version installed - detected version {flash_attention_version}. {install_message}"

@ -1332,20 +1394,23 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                " unexpected behaviour."
            )

        if device_map is None:
        # The check `torch.empty(0).device.type != "cuda"` is needed as the model may be initialized after `torch.set_default_device` has been called,
        # or the model may be initialized under the context manager `with torch.device("cuda"):`.
        if check_device_map and device_map is None and torch.empty(0).device.type != "cuda":
            if torch.cuda.is_available():
                logger.warning(
                    "You are attempting to use Flash Attention 2.0 with a model initialized on CPU. Make sure to move the model to GPU"
                    "You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU"
                    " after initializing it on CPU with `model.to('cuda')`."
                )
            else:
                raise ValueError(
                    "You are attempting to use Flash Attention 2.0 with a model initialized on CPU and with no GPU available. "
                    "You are attempting to use Flash Attention 2.0 with a model not initialized on GPU and with no GPU available. "
                    "This is not supported yet. Please make sure to have access to a GPU and either initialise the model on a GPU by passing a device_map "
                    "or initialising the model on CPU and then moving it to GPU."
                )
        elif (
            device_map is not None
            check_device_map
            and device_map is not None
            and isinstance(device_map, dict)
            and ("cpu" in device_map.values() or "disk" in device_map.values())
        ):

@ -1353,7 +1418,37 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                "You are attempting to use Flash Attention 2.0 with a model dispatched on CPU or disk. This is not supported. Please make sure to "
                "initialise the model on a GPU by passing a device_map that contains only GPU devices as keys."
            )
        config._flash_attn_2_enabled = True
        if not hard_check_only:
            config._attn_implementation = "flash_attention_2"
        return config
    @classmethod
    def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> PretrainedConfig:
        """
        Checks the availability of SDPA for a given model.

        If all checks pass and `hard_check_only` is False, the method will set the config attribute `_attn_implementation` to "sdpa" so that the model can initialize the correct attention module.
        """
        if hard_check_only:
            if not cls._supports_sdpa:
                raise ValueError(
                    f"{cls.__name__} does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention yet. Please open an issue on GitHub to "
                    "request support for this architecture: https://github.com/huggingface/transformers/issues/new"
                )
            if not is_torch_sdpa_available():
                raise ImportError(
                    "PyTorch SDPA requirements in Transformers are not met. Please install torch>=2.1.1."
                )

        if not is_torch_sdpa_available() or not cls._supports_sdpa:
            return config

        _is_bettertransformer = getattr(cls, "use_bettertransformer", False)
        if _is_bettertransformer:
            return config

        if not hard_check_only:
            config._attn_implementation = "sdpa"
        return config

    def enable_input_require_grads(self):

@ -3312,8 +3407,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
        elif load_in_8bit or load_in_4bit or low_cpu_mem_usage:
            init_contexts.append(init_empty_weights())

        if use_flash_attention_2:
            config = cls._check_and_enable_flash_attn_2(config, torch_dtype=torch_dtype, device_map=device_map)
        config = copy.deepcopy(config)  # We do not want to modify the config inplace in from_pretrained.
        config = cls._autoset_attn_implementation(
            config, use_flash_attention_2=use_flash_attention_2, torch_dtype=torch_dtype, device_map=device_map
        )

        with ContextManagers(init_contexts):
            model = cls(config, *model_args, **model_kwargs)
@ -389,7 +389,7 @@ class BarkSelfFlashAttention2(BarkSelfAttention):


BARK_ATTENTION_CLASSES = {
    "default": BarkSelfAttention,
    "eager": BarkSelfAttention,
    "flash_attention_2": BarkSelfFlashAttention2,
}

@ -436,8 +436,7 @@ class BarkBlock(nn.Module):
            self.layernorm_1 = nn.LayerNorm(config.hidden_size)
            self.layernorm_2 = nn.LayerNorm(config.hidden_size)

        attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
        self.attn = BARK_ATTENTION_CLASSES[attn_type](config, is_causal=is_causal)
        self.attn = BARK_ATTENTION_CLASSES[config._attn_implementation](config, is_causal=is_causal)

        self.mlp = BarkMLP(config)

@ -670,6 +669,7 @@ class BarkCausalModel(BarkPreTrainedModel):
        self.drop = nn.Dropout(config.dropout)

        self.layers = nn.ModuleList([BarkBlock(config, is_causal=True) for _ in range(config.num_layers)])
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"

        self.layernorm_final = BarkLayerNorm(config.hidden_size, bias=config.bias)

@ -805,7 +805,7 @@ class BarkCausalModel(BarkPreTrainedModel):
        if attention_mask is not None:
            if batch_size <= 0:
                raise ValueError("batch_size has to be defined and > 0")
            if getattr(self.config, "_flash_attn_2_enabled", False):
            if self._use_flash_attention_2:
                attention_mask = attention_mask if 0 in attention_mask else None
            else:
                attention_mask = attention_mask.view(batch_size, -1)

@ -1265,6 +1265,7 @@ class BarkFineModel(BarkPreTrainedModel):
        self.drop = nn.Dropout(config.dropout)

        self.layers = nn.ModuleList([BarkBlock(config, is_causal=False) for _ in range(config.num_layers)])
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"

        self.layernorm_final = nn.LayerNorm(config.hidden_size)

@ -1434,7 +1435,7 @@ class BarkFineModel(BarkPreTrainedModel):
        if attention_mask is not None:
            if batch_size <= 0:
                raise ValueError("batch_size has to be defined and > 0")
            if getattr(self.config, "_flash_attn_2_enabled", False):
            if self._use_flash_attention_2:
                attention_mask = attention_mask if 0 in attention_mask else None
            else:
                # [bsz, to_seq_length] -> [bsz, 1, 1, to_seq_length]
@ -1875,7 +1876,11 @@ class BarkModel(BarkPreTrainedModel):
|
||||
|
||||
@classmethod
|
||||
def _check_and_enable_flash_attn_2(
|
||||
cls, config, torch_dtype: Optional[torch.dtype] = None, device_map: Optional[Union[str, Dict[str, int]]] = None
|
||||
cls,
|
||||
config,
|
||||
torch_dtype: Optional[torch.dtype] = None,
|
||||
device_map: Optional[Union[str, Dict[str, int]]] = None,
|
||||
hard_check_only: bool = False,
|
||||
):
"""
`_check_and_enable_flash_attn_2` originally doesn't expand flash attention enabling to the model
@@ -1892,12 +1897,14 @@ class BarkModel(BarkPreTrainedModel):
The method checks if the current setup is compatible with Flash Attention as it requires the model to be in
half precision and not run on CPU.

If all checks pass, the method will create an attribute in the config `_flash_attn_2_enabled` so that the model
If all checks pass and `hard_check_only` is False, the method will set the config attribute `_attn_implementation` to "flash_attention_2" so that the model
can initialize the correct attention module
"""
config = super()._check_and_enable_flash_attn_2(config, torch_dtype, device_map)
config = super()._check_and_enable_flash_attn_2(
config, torch_dtype, device_map, hard_check_only=hard_check_only
)

config.semantic_config._flash_attn_2_enabled = getattr(config, "_flash_attn_2_enabled", False)
config.coarse_acoustics_config._flash_attn_2_enabled = getattr(config, "_flash_attn_2_enabled", False)
config.fine_acoustics_config._flash_attn_2_enabled = getattr(config, "_flash_attn_2_enabled", False)
config.semantic_config._attn_implementation = config._attn_implementation
config.coarse_acoustics_config._attn_implementation = config._attn_implementation
config.fine_acoustics_config._attn_implementation = config._attn_implementation
return config

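Because Bark is a composite model, the three propagation lines above are what keep the sub-models consistent with the top-level config. A hedged sanity-check sketch (the checkpoint name is an assumption, not taken from the diff):

```python
from transformers import BarkModel

# After loading, every Bark sub-config should carry the same `_attn_implementation`
# as the composite config, mirroring the propagation in BarkModel above.
model = BarkModel.from_pretrained("suno/bark-small")  # placeholder checkpoint
cfg = model.config
for sub in (cfg.semantic_config, cfg.coarse_acoustics_config, cfg.fine_acoustics_config):
    assert sub._attn_implementation == cfg._attn_implementation
```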
@@ -25,7 +25,12 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
from ...modeling_attn_mask_utils import (
_prepare_4d_attention_mask,
_prepare_4d_attention_mask_for_sdpa,
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
@@ -505,8 +510,109 @@ class BartFlashAttention2(BartAttention):
)


class BartSdpaAttention(BartAttention):
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
if output_attentions or layer_head_mask is not None:
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
logger.warning_once(
"BartModel is using BartSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states,
key_value_states=key_value_states,
past_key_value=past_key_value,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)

# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None

bsz, tgt_len, _ = hidden_states.size()

# get query proj
query_states = self.q_proj(hidden_states)
# get key, value proj
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
# is checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
if (
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)

query_states = self._shape(query_states, tgt_len, bsz)

attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=self.dropout if self.training else 0.0,
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
is_causal=self.is_causal and attention_mask is None and tgt_len > 1,
)

if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)

attn_output = attn_output.transpose(1, 2)

# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned across GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

attn_output = self.out_proj(attn_output)

return attn_output, None, past_key_value


BART_ATTENTION_CLASSES = {
"default": BartAttention,
"eager": BartAttention,
"sdpa": BartSdpaAttention,
"flash_attention_2": BartFlashAttention2,
}
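The SDPA call above is intended to be numerically equivalent to the manual masked softmax it replaces. A hedged standalone check (not part of the diff; shapes are arbitrary) that the `is_causal=True` fast path matches an explicit causal mask:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = torch.randn(1, 2, 5, 8)  # (bsz, num_heads, tgt_len, head_dim)
k = torch.randn(1, 2, 5, 8)
v = torch.randn(1, 2, 5, 8)

# SDPA with its built-in causal masking.
sdpa = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Manual attention: scaled scores + additive causal mask + softmax.
scores = q @ k.transpose(-1, -2) / (q.shape[-1] ** 0.5)
causal = torch.full((5, 5), float("-inf")).triu(1)
manual = torch.softmax(scores + causal, dim=-1) @ v

print(torch.allclose(sdpa, manual, atol=1e-5))  # expected: True
```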

@@ -515,9 +621,8 @@ class BartEncoderLayer(nn.Module):
def __init__(self, config: BartConfig):
super().__init__()
self.embed_dim = config.d_model
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"

self.self_attn = BART_ATTENTION_CLASSES[attn_type](
self.self_attn = BART_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
@@ -587,8 +692,7 @@ class BartDecoderLayer(nn.Module):
super().__init__()
self.embed_dim = config.d_model

attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.self_attn = BART_ATTENTION_CLASSES[attn_type](
self.self_attn = BART_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
@@ -601,7 +705,7 @@ class BartDecoderLayer(nn.Module):
self.activation_dropout = config.activation_dropout

self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.encoder_attn = BART_ATTENTION_CLASSES[attn_type](
self.encoder_attn = BART_ATTENTION_CLASSES[config._attn_implementation](
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
@@ -735,6 +839,7 @@ class BartPreTrainedModel(PreTrainedModel):
_no_split_modules = [r"BartEncoderLayer", r"BartDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_sdpa = True

def _init_weights(self, module):
std = self.config.init_std
@@ -961,6 +1066,8 @@ class BartEncoder(BartPreTrainedModel):
embed_dim,
)
self.layers = nn.ModuleList([BartEncoderLayer(config) for _ in range(config.encoder_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self._use_sdpa = config._attn_implementation == "sdpa"
self.layernorm_embedding = nn.LayerNorm(embed_dim)

self.gradient_checkpointing = False
@@ -1048,8 +1155,13 @@ class BartEncoder(BartPreTrainedModel):

# expand attention_mask
if attention_mask is not None:
if getattr(self.config, "_flash_attn_2_enabled", False):
if self._use_flash_attention_2:
attention_mask = attention_mask if 0 in attention_mask else None
elif self._use_sdpa and head_mask is None and not output_attentions:
# output_attentions=True & head_mask can not be supported when using SDPA, fall back to
# the manual implementation that requires a 4D causal mask in all cases.
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
else:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
@@ -1136,6 +1248,9 @@ class BartDecoder(BartPreTrainedModel):
config.d_model,
)
self.layers = nn.ModuleList([BartDecoderLayer(config) for _ in range(config.decoder_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self._use_sdpa = config._attn_implementation == "sdpa"

self.layernorm_embedding = nn.LayerNorm(config.d_model)

self.gradient_checkpointing = False
@@ -1254,9 +1369,18 @@ class BartDecoder(BartPreTrainedModel):
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input) * self.embed_scale

if getattr(self.config, "_flash_attn_2_enabled", False):
if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None:
# output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
input_shape,
inputs_embeds,
past_key_values_length,
)
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
@@ -1265,8 +1389,17 @@ class BartDecoder(BartPreTrainedModel):

# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
if getattr(self.config, "_flash_attn_2_enabled", False):
if self._use_flash_attention_2:
encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
elif self._use_sdpa and cross_attn_head_mask is None and not output_attentions:
# output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
encoder_attention_mask,
inputs_embeds.dtype,
tgt_len=input_shape[-1],
)
else:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask(

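For readers unfamiliar with the `[bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]` comments above, a hedged conceptual sketch of what the `_prepare_4d_attention_mask*` helpers produce (this mimics their observable output, not their exact internals):

```python
import torch

mask_2d = torch.tensor([[1, 1, 1, 0]])  # one padded position
bsz, src_len = mask_2d.shape
tgt_len = src_len
dtype = torch.float32

# Expand the 2D padding mask to 4D and turn masked positions into the dtype minimum,
# so they vanish after the softmax when added to attention scores.
expanded = mask_2d[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
additive = (1.0 - expanded) * torch.finfo(dtype).min
print(additive.shape)  # torch.Size([1, 1, 4, 4])
```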
@@ -252,7 +252,7 @@ class BlenderbotAttention(nn.Module):
return attn_output, attn_weights_reshaped, past_key_value


BLENDERBOT_ATTENTION_CLASSES = {"default": BlenderbotAttention}
BLENDERBOT_ATTENTION_CLASSES = {"eager": BlenderbotAttention}


# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Blenderbot, MBART->BLENDERBOT
@@ -260,9 +260,8 @@ class BlenderbotEncoderLayer(nn.Module):
def __init__(self, config: BlenderbotConfig):
super().__init__()
self.embed_dim = config.d_model
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"

self.self_attn = BLENDERBOT_ATTENTION_CLASSES[attn_type](
self.self_attn = BLENDERBOT_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
@@ -332,9 +331,8 @@ class BlenderbotDecoderLayer(nn.Module):
def __init__(self, config: BlenderbotConfig):
super().__init__()
self.embed_dim = config.d_model
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"

self.self_attn = BLENDERBOT_ATTENTION_CLASSES[attn_type](
self.self_attn = BLENDERBOT_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
@@ -347,7 +345,7 @@ class BlenderbotDecoderLayer(nn.Module):
self.activation_dropout = config.activation_dropout

self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.encoder_attn = BLENDERBOT_ATTENTION_CLASSES[attn_type](
self.encoder_attn = BLENDERBOT_ATTENTION_CLASSES[config._attn_implementation](
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,

@@ -254,9 +254,8 @@ class BlenderbotSmallEncoderLayer(nn.Module):
def __init__(self, config: BlenderbotSmallConfig):
super().__init__()
self.embed_dim = config.d_model
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"

self.self_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[attn_type](
self.self_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
@@ -321,7 +320,10 @@ class BlenderbotSmallEncoderLayer(nn.Module):
return outputs


BLENDERBOT_SMALL_ATTENTION_CLASSES = {"default": BlenderbotSmallAttention}
# TODO: Implement attention with SDPA for TimeSeriesTransformer.
BLENDERBOT_SMALL_ATTENTION_CLASSES = {
"eager": BlenderbotSmallAttention,
}


# Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->BlenderbotSmall, BART->BLENDERBOT_SMALL
@@ -330,8 +332,7 @@ class BlenderbotSmallDecoderLayer(nn.Module):
super().__init__()
self.embed_dim = config.d_model

attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.self_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[attn_type](
self.self_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
@@ -344,7 +345,7 @@ class BlenderbotSmallDecoderLayer(nn.Module):
self.activation_dropout = config.activation_dropout

self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.encoder_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[attn_type](
self.encoder_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[config._attn_implementation](
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,

@@ -471,6 +471,12 @@ class FFN(nn.Module):
return x


DISTILBERT_ATTENTION_CLASSES = {
"eager": MultiHeadSelfAttention,
"flash_attention_2": DistilBertFlashAttention2,
}


class TransformerBlock(nn.Module):
def __init__(self, config: PretrainedConfig):
super().__init__()
@@ -479,11 +485,7 @@ class TransformerBlock(nn.Module):
if config.dim % config.n_heads != 0:
raise ValueError(f"config.n_heads {config.n_heads} must divide config.dim {config.dim} evenly")

self.attention = (
MultiHeadSelfAttention(config)
if not getattr(config, "_flash_attn_2_enabled", False)
else DistilBertFlashAttention2(config)
)
self.attention = DISTILBERT_ATTENTION_CLASSES[config._attn_implementation](config)
self.sa_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12)

self.ffn = FFN(config)
@@ -703,6 +705,7 @@ class DistilBertModel(DistilBertPreTrainedModel):

self.embeddings = Embeddings(config) # Embeddings
self.transformer = Transformer(config) # Encoder
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"

# Initialize weights and apply final processing
self.post_init()
@@ -808,7 +811,7 @@ class DistilBertModel(DistilBertPreTrainedModel):

embeddings = self.embeddings(input_ids, inputs_embeds) # (bs, seq_length, dim)

if getattr(self.config, "_flash_attn_2_enabled", False):
if self._use_flash_attention_2:
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
if attention_mask is None:

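The DistilBert change above is the same refactor repeated across every model in this PR: replace `if/else` constructor chains with a dict keyed by `config._attn_implementation`. A hedged standalone sketch of the pattern (all names below are hypothetical, mirroring `DISTILBERT_ATTENTION_CLASSES`):

```python
# Registry-style dispatch: one dict lookup instead of chained conditionals.
class EagerAttention:
    pass

class FlashAttention2:
    pass

MY_MODEL_ATTENTION_CLASSES = {
    "eager": EagerAttention,
    "flash_attention_2": FlashAttention2,
}

class DummyConfig:
    _attn_implementation = "eager"

attention_cls = MY_MODEL_ATTENTION_CLASSES[DummyConfig._attn_implementation]
print(attention_cls.__name__)  # EagerAttention
```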
@@ -16,7 +16,7 @@

import math
import warnings
from typing import Optional, Tuple, Union
from typing import TYPE_CHECKING, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
@@ -24,7 +24,11 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
from torch.nn import functional as F

from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from ...modeling_attn_mask_utils import (
AttentionMaskConverter,
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
@@ -33,6 +37,7 @@ from ...modeling_outputs import (
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import is_torch_greater_or_equal_than_2_0
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
@@ -44,6 +49,9 @@ from ...utils import (
from .configuration_falcon import FalconConfig


if TYPE_CHECKING:
from ...configuration_utils import PretrainedConfig

if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
@@ -278,6 +286,7 @@ class FalconAttention(nn.Module):
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.is_causal = True
self._use_sdpa = config._attn_implementation == "sdpa"

if self.head_dim * self.num_heads != self.hidden_size:
raise ValueError(
@@ -439,16 +448,15 @@ class FalconAttention(nn.Module):
present = None

if alibi is None:
if hasattr(F, "scaled_dot_product_attention") and not output_attentions:
# TODO: deprecate this once we add FA2 support in Falcon
logger.warning_once(
"The current implementation of Falcon calls `torch.scaled_dot_product_attention` directly, this will be deprecated in the"
" future in favor of the `BetterTransformer` API. Please install the latest optimum library with `pip install -U optimum` and call "
"`model.to_bettertransformer()` to benefit from `torch.scaled_dot_product_attention` and future performance optimizations."
)

if self._use_sdpa and not output_attentions:
attn_output = F.scaled_dot_product_attention(
query_layer, key_layer, value_layer, attention_mask, 0.0, is_causal=False
query_layer,
key_layer,
value_layer,
attention_mask,
0.0,
# The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case query_length == 1.
is_causal=self.is_causal and attention_mask is None and query_length > 1,
)
attention_scores = None
else:
@@ -456,58 +464,70 @@ class FalconAttention(nn.Module):
attention_scores /= math.sqrt(self.head_dim)

attention_scores = F.softmax(attention_scores + attention_mask, dim=-1, dtype=hidden_states.dtype)
# It is unclear why neither dropout nor head_mask is applied here (while it is with alibi).
attn_output = attention_scores @ value_layer

attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim)
attn_output = attn_output.permute(0, 2, 1, 3)
attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)

output_tensor = self.dense(attn_output)
attn_output = self.dense(attn_output)

if output_attentions:
return output_tensor, present, attention_scores
return attn_output, present, attention_scores
else:
return output_tensor, present
return attn_output, present

else:
matmul_result = query_layer @ key_layer.transpose(-1, -2)
if self._use_sdpa and not output_attentions and head_mask is None:
attn_output = F.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
attn_mask=attention_mask,
dropout_p=self.attention_dropout.p if self.training else 0.0,
is_causal=self.is_causal and attention_mask is None and query_length > 1,
)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)

# change view to [batch_size, num_heads, q_length, kv_length]
attention_scores = matmul_result.view(batch_size, self.num_heads, query_length, kv_length)
attn_output = self.dense(attn_output)
else:
matmul_result = query_layer @ key_layer.transpose(-1, -2)

# cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
input_dtype = attention_scores.dtype
# `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
attention_scores = attention_scores.to(torch.float32)
# Matt (HF) note: We could possibly use F.scaled_dot_product_attention here too, by
# adding (alibi * self.inv_norm_factor) to attention_mask. I think this would be mathematically
# equivalent and more performant, but there might be a numerical difference. If you're reading this
# and you'd like to experiment and maybe file a PR, feel free!
attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)
attention_logits *= self.inv_norm_factor
attention_probs = F.softmax(attention_logits + attention_mask, dim=-1, dtype=hidden_states.dtype)
# [batch_size, num_heads, q_length, kv_length]
attention_probs = self.attention_dropout(attention_probs)
# change view to [batch_size, num_heads, q_length, kv_length]
attention_scores = matmul_result.view(batch_size, self.num_heads, query_length, kv_length)

if head_mask is not None:
attention_probs = attention_probs * head_mask
# cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
input_dtype = attention_scores.dtype
# `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
attention_scores = attention_scores.to(torch.float32)

# change view [batch_size, num_heads, q_length, kv_length]
attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length)
attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)
attention_logits *= self.inv_norm_factor
attention_probs = F.softmax(attention_logits + attention_mask, dim=-1, dtype=hidden_states.dtype)
# [batch_size, num_heads, q_length, kv_length]
attention_probs = self.attention_dropout(attention_probs)

# matmul: [batch_size * num_heads, q_length, head_dim]
context_layer = (attention_probs_reshaped @ value_layer).flatten(0, 1)
if head_mask is not None:
attention_probs = attention_probs * head_mask

# change view [batch_size, q_length, num_heads * head_dim]
context_layer = self._merge_heads(context_layer)
# change view [batch_size, num_heads, q_length, kv_length]
attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length)

output_tensor = self.dense(context_layer)
# matmul: [batch_size * num_heads, q_length, head_dim]
attn_output = (attention_probs_reshaped @ value_layer).flatten(0, 1)

# change view [batch_size, q_length, num_heads * head_dim]
attn_output = self._merge_heads(attn_output)

attn_output = self.dense(attn_output)

if output_attentions:
return output_tensor, present, attention_probs
return attn_output, present, attention_probs
else:
return output_tensor, present
return attn_output, present

class FalconFlashAttention2(FalconAttention):
@@ -734,17 +754,20 @@ class FalconMLP(nn.Module):
return x


FALCON_ATTENTION_CLASSES = {
"eager": FalconAttention,
"sdpa": FalconAttention, # FalconAttention originally implemented both a forward with & without SDPA
"flash_attention_2": FalconFlashAttention2,
}


class FalconDecoderLayer(nn.Module):
def __init__(self, config: FalconConfig):
super().__init__()
hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads

self.self_attention = (
FalconAttention(config)
if not getattr(config, "_flash_attn_2_enabled", False)
else FalconFlashAttention2(config)
)
self.self_attention = FALCON_ATTENTION_CLASSES[config._attn_implementation](config)
self.mlp = FalconMLP(config)
self.hidden_dropout = config.hidden_dropout
self.config = config
@@ -912,6 +935,7 @@ class FalconPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["FalconDecoderLayer"]
_supports_flash_attn_2 = True
_supports_sdpa = True

def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
@@ -932,6 +956,25 @@ class FalconPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)

# Adapted from transformers.modeling_utils.PreTrainedModel._check_and_enable_sdpa
@classmethod
def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> "PretrainedConfig":
# NOTE: Falcon supported SDPA from PyTorch 2.0. We keep it like that for backward compatibility (automatically use SDPA for torch>=2.0).
if hard_check_only:
if not is_torch_greater_or_equal_than_2_0:
raise ImportError("PyTorch SDPA requirements in Transformers are not met. Please install torch>=2.0.")

if not is_torch_greater_or_equal_than_2_0:
return config

_is_bettertransformer = getattr(cls, "use_bettertransformer", False)
if _is_bettertransformer:
return config

if not hard_check_only:
config._attn_implementation = "sdpa"
return config

@add_start_docstrings(
"The bare Falcon Model transformer outputting raw hidden-states without any specific head on top.",
@@ -950,6 +993,8 @@ class FalconModel(FalconPreTrainedModel):

# Transformer blocks
self.h = nn.ModuleList([FalconDecoderLayer(config) for _ in range(config.num_hidden_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self._use_sdpa = config._attn_implementation == "sdpa"

# Final Layer Norm
self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
@@ -1003,12 +1048,6 @@ class FalconModel(FalconPreTrainedModel):
if past_key_values is None:
past_key_values = tuple([None] * len(self.h))

# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape batch_size x num_heads x N x N
# head_mask has shape n_layer x batch x num_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)

@@ -1047,15 +1086,61 @@ class FalconModel(FalconPreTrainedModel):
)
position_ids = position_ids.unsqueeze(0)

if getattr(self.config, "_flash_attn_2_enabled", False):
if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self._use_sdpa and not output_attentions:
# output_attentions=True can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
if alibi is None:
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
)
elif head_mask is None:
alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:])

attention_mask_2d = attention_mask
# We don't call _prepare_4d_causal_attention_mask_for_sdpa as we need to mask alibi using the 4D attention_mask untouched.
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)

# We take care to integrate alibi bias in the attention_mask here.
if attention_mask_2d is None:
attention_mask = alibi / math.sqrt(self.config.hidden_size // self.num_heads)
else:
attention_mask = torch.masked_fill(
alibi / math.sqrt(self.config.hidden_size // self.num_heads),
attention_mask < -1,
torch.finfo(alibi.dtype).min,
)

# From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
# produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
if seq_length > 1:
attention_mask = AttentionMaskConverter._unmask_unattended(
attention_mask, attention_mask_2d, unmasked_value=0.0
)
else:
# PyTorch SDPA does not support head_mask, we fall back on the eager implementation in this case.
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)

# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape batch_size x num_heads x N x N
# head_mask has shape n_layer x batch x num_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

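The Falcon-with-alibi branch above works around SDPA having no dedicated alibi argument: the bias is folded into the additive `attn_mask`, with padded positions overwritten by the dtype minimum. A hedged standalone illustration of that trick (shapes and the random "alibi" values are placeholders, not the real slope computation):

```python
import torch
import torch.nn.functional as F

bsz, n_heads, q_len, kv_len, head_dim = 1, 2, 4, 4, 8
alibi_bias = torch.randn(bsz, n_heads, 1, kv_len)   # stand-in for the real alibi bias
padding = torch.tensor([[1, 1, 1, 0]]).bool()        # last key position is padding

# Fold the positional bias into the additive mask, then mask out padded keys.
attn_mask = alibi_bias.expand(bsz, n_heads, q_len, kv_len).clone()
attn_mask = attn_mask.masked_fill(~padding[:, None, None, :], torch.finfo(attn_mask.dtype).min)

q = torch.randn(bsz, n_heads, q_len, head_dim)
k = torch.randn(bsz, n_heads, kv_len, head_dim)
v = torch.randn(bsz, n_heads, kv_len, head_dim)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
print(out.shape)  # torch.Size([1, 2, 4, 8])
```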
@@ -22,6 +22,7 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
@@ -128,6 +129,7 @@ class GPTBigCodeAttention(nn.Module):
self.scale_attention_softmax_in_fp32 = (
config.scale_attention_softmax_in_fp32 and config.attention_softmax_in_fp32
)
self.attn_pdrop = config.attn_pdrop

if self.is_cross_attention:
if self.multi_query:
@@ -359,7 +361,7 @@ class GPTBigCodeFlashAttention2(GPTBigCodeAttention):
key = key.transpose(1, 2).reshape(batch_size, tgt, self.num_heads, self.head_dim)
value = value.transpose(1, 2).reshape(batch_size, tgt, self.num_heads, self.head_dim)

attn_dropout = self.config.attn_pdrop if self.training else 0.0
attn_dropout = self.attn_pdrop if self.training else 0.0

softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else query.dtype
upcast = query.dtype != softmax_dtype
@@ -509,6 +511,137 @@ class GPTBigCodeFlashAttention2(GPTBigCodeAttention):
)

class GPTBigCodeSdpaAttention(GPTBigCodeAttention):
def _attn(self, query, key, value, attention_mask=None, head_mask=None):
if head_mask is not None:
# The super dispatch is done in the forward.
raise ValueError(
"PyTorch SDPA does not support head_mask. Please open an issue in Transformers repository."
)

scale = None
if not self.scale_attn_weights:
scale = 1

# MQA models: (batch_size, query_length, num_heads * head_dim)
# MHA models: (batch_size, num_heads, query_length, head_dim)
query_shape = query.shape
batch_size = query_shape[0]
key.shape[-2]

if self.multi_query:
query_length = query_shape[1]

# NOTE: Maybe there is better than this?
query = query.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)

# Without these unsqueeze, SDPA complains as the query and key/value have a different number of dimensions.
key = key.unsqueeze(1)
value = value.unsqueeze(1)

# Although these expand are not numerically useful, PyTorch 2.1 can not dispatch to mem-efficient attention
# and flash attention (No available kernel. Aborting execution.) from the shapes
# query = [batch_size, num_heads, query_length, head_dim]
# key = [batch_size, 1, past_length, head_dim]
# value = [batch_size, 1, past_length, head_dim]
# which is unfortunate. Hopefully can be improved in the future. These expand should not be too expensive as they do not do memory copy.
key = key.expand(-1, self.num_heads, -1, -1)
value = value.expand(-1, self.num_heads, -1, -1)
else:
query_length = query_shape[-1]

sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query,
key,
value,
attn_mask=attention_mask,
dropout_p=self.attn_pdrop if self.training else 0.0,
# The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case query_length == 1.
is_causal=self.is_causal and attention_mask is None and query_length > 1,
scale=scale,
)

if self.multi_query:
# (batch_size, num_heads, seq_len, head_dim) --> (batch_size, seq_len, num_heads, head_dim)
sdpa_result = sdpa_result.transpose(1, 2)

# Reshape is kind of expensive here, as it does a memory copy,
# but I did not manage to make away without it (logits do not match when using view)
# (batch_size, seq_len, num_heads, head_dim) --> (batch_size, seq_len, num_heads * head_dim)
sdpa_result = sdpa_result.reshape(query_shape)

return sdpa_result, None
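The unsqueeze/expand dance above is what lets a multi-query (MQA) layout go through SDPA. A hedged standalone sketch of the same shape fix (dimensions are arbitrary, not taken from the diff):

```python
import torch
import torch.nn.functional as F

batch, n_heads, q_len, kv_len, head_dim = 2, 8, 5, 5, 16
query = torch.randn(batch, n_heads, q_len, head_dim)
key = torch.randn(batch, kv_len, head_dim)    # one shared KV head (MQA layout)
value = torch.randn(batch, kv_len, head_dim)

# Broadcast the single KV head across the query heads without copying memory,
# so SDPA sees matching 4D shapes on all three inputs.
key = key.unsqueeze(1).expand(-1, n_heads, -1, -1)
value = value.unsqueeze(1).expand(-1, n_heads, -1, -1)

out = F.scaled_dot_product_attention(query, key, value, is_causal=True)
print(out.shape)  # torch.Size([2, 8, 5, 16])
```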

def forward(
self,
hidden_states: torch.Tensor,
layer_past: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
) -> Union[
Tuple[torch.Tensor, Optional[torch.Tensor]],
Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]],
]:
if encoder_hidden_states is not None:
if not hasattr(self, "q_attn") or not self.is_cross_attention:
raise ValueError(
"If class is used as cross attention, the weights `q_attn` have to be defined. "
"Please make sure to instantiate class with `GPTBigCodeAttention(..., is_cross_attention=True)`."
)

query = self.q_attn(hidden_states)
key_value = self.c_attn(encoder_hidden_states)
attention_mask = encoder_attention_mask
elif self.multi_query:
query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2)
else:
# Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim),
# i.e., the memory layout is not the same as GPT2.
# This makes the concatenation with past_key_value more efficient.
query, key_value = (
self.c_attn(hidden_states)
.view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim)
.transpose(1, 2)
.split((self.head_dim, 2 * self.head_dim), dim=3)
)

if layer_past is not None:
key_value = torch.cat((layer_past, key_value), dim=-2)
present = key_value if use_cache else None

key, value = key_value.split((self.head_dim, self.head_dim), dim=-1)

if not output_attentions and head_mask is None:
# Difference with the original implementation: there is no need to transpose the key here,
# as SDPA expects seq_length to be at index -2 for the key as well
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
else:
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
logger.warning_once(
"GPTBigCodeModel is using GPTBigCodeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` and `head_mask` not None."
' Falling back to the manual attention implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
attn_output, attn_weights = super()._attn(query, key.transpose(-1, -2), value, attention_mask, head_mask)

if not self.multi_query:
attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape)
attn_output = self.c_proj(attn_output)
attn_output = self.resid_dropout(attn_output)

outputs = (attn_output, present)
if output_attentions:
if self.multi_query:
# Transpose to return weights in the usual format (batch_size, num_heads, query_length, key_length)
attn_weights = attn_weights.transpose(1, 2)
outputs += (attn_weights,)

return outputs

class GPTBigCodeMLP(nn.Module):
def __init__(self, intermediate_size, config):
super().__init__()
@@ -527,6 +660,13 @@ class GPTBigCodeMLP(nn.Module):
return hidden_states


GPTBIGCODE_ATTENTION_CLASSES = {
"eager": GPTBigCodeAttention,
"flash_attention_2": GPTBigCodeFlashAttention2,
"sdpa": GPTBigCodeSdpaAttention,
}


class GPTBigCodeBlock(nn.Module):
def __init__(self, config, layer_idx=None):
super().__init__()
@@ -534,21 +674,19 @@ class GPTBigCodeBlock(nn.Module):
self.inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size

self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = (
GPTBigCodeAttention(config, layer_idx=layer_idx)
if not getattr(config, "_flash_attn_2_enabled", False)
else GPTBigCodeFlashAttention2(config, layer_idx=layer_idx)
)

self.attn = GPTBIGCODE_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)

self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

if config.add_cross_attention:
if config.multi_query:
raise NotImplementedError("Cross-attention not implemented for MQA")
self.crossattention = (
GPTBigCodeAttention(config, is_cross_attention=True, layer_idx=layer_idx)
if not getattr(config, "_flash_attn_2_enabled", False)
else GPTBigCodeFlashAttention2(config, is_cross_attention=True, layer_idx=layer_idx)

self.crossattention = GPTBIGCODE_ATTENTION_CLASSES[config._attn_implementation](
config, is_cross_attention=True, layer_idx=layer_idx
)

self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

self.mlp = GPTBigCodeMLP(self.inner_dim, config)
@@ -629,6 +767,7 @@ class GPTBigCodePreTrainedModel(PreTrainedModel):
_no_split_modules = ["GPTBigCodeBlock"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_sdpa = True

def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
@@ -770,6 +909,9 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):

self.gradient_checkpointing = False

self._use_sdpa = config._attn_implementation == "sdpa"
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"

# Initialize weights and apply final processing
self.post_init()

@@ -850,7 +992,7 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
key_length = past_length + query_length
self_attention_mask = self.bias[None, key_length - query_length : key_length, :key_length]

if getattr(self.config, "_flash_attn_2_enabled", False):
if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask.bool() if (attention_mask is not None and 0 in attention_mask) else None
encoder_attention_mask = (
@@ -867,7 +1009,34 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):

# MQA models: (batch_size, query_length, n_heads, key_length)
# MHA models: (batch_size, n_heads, query_length, key_length)
attention_mask = self_attention_mask.unsqueeze(2 if self.multi_query else 1)
self_attention_mask = self_attention_mask.unsqueeze(2 if self.multi_query else 1)

if self._use_sdpa and head_mask is None and not output_attentions:
# output_attentions=True can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
if self.multi_query:
# gpt_bigcode using MQA has the bad taste to use a causal mask with shape
# [batch_size, target_length, 1, source_length], not compatible with SDPA, hence this transpose.
self_attention_mask = self_attention_mask.transpose(1, 2)

if query_length > 1 and attention_mask is not None:
# From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
# produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
self_attention_mask = AttentionMaskConverter._unmask_unattended(
self_attention_mask, attention_mask, unmasked_value=True
)

# SDPA with a custom mask is much faster in fp16/fp32 dtype rather than bool. Cast here to floating point instead of at every layer.
dtype = self.wte.weight.dtype
self_attention_mask = torch.where(
self_attention_mask,
torch.full([], 0.0, dtype=dtype, device=self_attention_mask.device),
torch.full(
[], torch.finfo(self.wte.weight.dtype).min, dtype=dtype, device=self_attention_mask.device
),
)

attention_mask = self_attention_mask

# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]

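The `torch.where` block above converts GPTBigCode's boolean causal mask into an additive floating-point mask once, instead of at every layer. A hedged miniature of that conversion (dtype and mask values chosen for illustration only):

```python
import torch

bool_mask = torch.tensor([[True, True, False]])  # True = attend, False = masked
dtype = torch.float16

# True becomes 0.0 (no change to the score), False becomes the dtype minimum
# so the position disappears after the softmax.
float_mask = torch.where(
    bool_mask,
    torch.full([], 0.0, dtype=dtype),
    torch.full([], torch.finfo(dtype).min, dtype=dtype),
)
print(float_mask)  # tensor([[ 0.0000e+00,  0.0000e+00, -6.5504e+04]], dtype=torch.float16)
```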
@@ -487,6 +487,12 @@ class GPTNeoFlashAttention2(GPTNeoSelfAttention):
)


GPT_NEO_ATTENTION_CLASSES = {
"eager": GPTNeoSelfAttention,
"flash_attention_2": GPTNeoFlashAttention2,
}


class GPTNeoAttention(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
@@ -495,11 +501,7 @@ class GPTNeoAttention(nn.Module):
self.attention_type = self.attention_layers[layer_id]

if self.attention_type in ["global", "local"]:
self.attention = (
GPTNeoSelfAttention(config, self.attention_type)
if not getattr(config, "_flash_attn_2_enabled", False)
else GPTNeoFlashAttention2(config, self.attention_type)
)
self.attention = GPT_NEO_ATTENTION_CLASSES[config._attn_implementation](config, self.attention_type)
else:
raise NotImplementedError(
"Only attn layer types 'global' and 'local' exist, but got `config.attention_layers`: "
@@ -718,6 +720,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.drop = nn.Dropout(float(config.embed_dropout))
self.h = nn.ModuleList([GPTNeoBlock(config, layer_id=i) for i in range(config.num_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

self.gradient_checkpointing = False
@@ -795,7 +798,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
hidden_states = inputs_embeds + position_embeds

# Attention mask.
if getattr(self.config, "_flash_attn_2_enabled", False):
if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:

@@ -658,6 +658,12 @@ class GPTNeoXMLP(nn.Module):
return hidden_states


GPT_NEOX_ATTENTION_CLASSES = {
"eager": GPTNeoXAttention,
"flash_attention_2": GPTNeoXFlashAttention2,
}


class GPTNeoXLayer(nn.Module):
def __init__(self, config):
super().__init__()
@@ -666,11 +672,7 @@ class GPTNeoXLayer(nn.Module):
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.post_attention_dropout = nn.Dropout(config.hidden_dropout)
self.post_mlp_dropout = nn.Dropout(config.hidden_dropout)
self.attention = (
GPTNeoXAttention(config)
if not getattr(config, "_flash_attn_2_enabled", False)
else GPTNeoXFlashAttention2(config)
)
self.attention = GPT_NEOX_ATTENTION_CLASSES[config._attn_implementation](config)
self.mlp = GPTNeoXMLP(config)

def forward(
@@ -785,6 +787,7 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
self.emb_dropout = nn.Dropout(config.hidden_dropout)
self.layers = nn.ModuleList([GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)])
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"

self.gradient_checkpointing = False

@@ -861,7 +864,7 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
if attention_mask is not None:
assert batch_size > 0, "batch_size has to be defined and > 0"
attention_mask = attention_mask.view(batch_size, -1)
if getattr(self.config, "_flash_attn_2_enabled", False):
if self._use_flash_attention_2:
attention_mask = attention_mask if 0 in attention_mask else None
else:
# We create a 3D attention mask from a 2D tensor mask.

@@ -29,7 +29,7 @@ from torch.nn import CrossEntropyLoss

from ... import PreTrainedModel
from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask_for_sdpa
from ...modeling_outputs import ModelOutput
from ...modeling_utils import PretrainedConfig
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
@@ -578,6 +578,7 @@ class IdeficsAttention(nn.Module):
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads
self.dropout = dropout
self.is_causal = True

if (self.head_dim * num_heads) != self.hidden_size:
raise ValueError(
@@ -693,6 +694,8 @@ class IdeficsAttention(nn.Module):
value_states,
attn_mask=attention_mask,
dropout_p=self.dropout,
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
is_causal=self.is_causal and attention_mask is None and q_len > 1,
)

if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
@@ -960,6 +963,7 @@ class IdeficsPreTrainedModel(PreTrainedModel):
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["IdeficsDecoderLayer", "IdeficsGatedCrossAttentionLayer"]
_supports_sdpa = True

def _init_weights(self, module):
# important: this ported version of Idefics isn't meant for training from scratch - only
@@ -975,6 +979,18 @@ class IdeficsPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()

# Adapted from transformers.modeling_utils.PreTrainedModel._check_and_enable_sdpa
@classmethod
def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> PretrainedConfig:
# We remove the checks on `is_torch_sdpa_available()` and `cls._supports_sdpa` as Idefics supports SDPA from torch==2.0.0 (no requirement on 2.1).
_is_bettertransformer = getattr(cls, "use_bettertransformer", False)
if _is_bettertransformer:
return config

if not hard_check_only:
config._attn_implementation = "sdpa"
return config

LLAMA_INPUTS_DOCSTRING = r"""
Args:
@@ -1240,7 +1256,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
attention_mask = torch.ones(
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
)
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)

@@ -34,6 +34,7 @@ from ...modeling_attn_mask_utils import (
AttentionMaskConverter,
_prepare_4d_attention_mask,
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
from ...modeling_utils import PreTrainedModel
@@ -518,7 +519,7 @@ class LlamaFlashAttention2(LlamaAttention):
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)

dropout_rate = 0.0 if not self.training else self.attention_dropout
dropout_rate = self.attention_dropout if self.training else 0.0

# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
@@ -654,15 +655,99 @@ class LlamaFlashAttention2(LlamaAttention):
)

class LlamaSdpaAttention(LlamaAttention):
"""
Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
`LlamaAttention` as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
the SDPA API.
"""

# Adapted from LlamaAttention.forward
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if output_attentions:
|
||||
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
|
||||
logger.warning_once(
|
||||
"LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
|
||||
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||
|
||||
if past_key_value is not None:
|
||||
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=attention_mask,
|
||||
dropout_p=self.attention_dropout if self.training else 0.0,
|
||||
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
|
||||
is_causal=self.is_causal and attention_mask is None and q_len > 1,
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output, None, past_key_value
|
||||
|
||||
|
||||
LLAMA_ATTENTION_CLASSES = {
|
||||
"eager": LlamaAttention,
|
||||
"flash_attention_2": LlamaFlashAttention2,
|
||||
"sdpa": LlamaSdpaAttention,
|
||||
}
|
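With this mapping in place, the attention backend becomes a load-time choice. A minimal sketch, assuming the `attn_implementation` loading argument added in this PR and the Llama 2 checkpoint already used in the tests further below:

    import torch
    from transformers import AutoModelForCausalLM

    # "sdpa" routes through LlamaSdpaAttention; "eager" and "flash_attention_2" select the other entries above.
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",
        torch_dtype=torch.float16,
        attn_implementation="sdpa",
    )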
||||
|
||||
|
||||
class LlamaDecoderLayer(nn.Module):
|
||||
def __init__(self, config: LlamaConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.self_attn = (
|
||||
LlamaAttention(config=config, layer_idx=layer_idx)
|
||||
if not getattr(config, "_flash_attn_2_enabled", False)
|
||||
else LlamaFlashAttention2(config=config, layer_idx=layer_idx)
|
||||
)
|
||||
|
||||
self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
|
||||
|
||||
self.mlp = LlamaMLP(config)
|
||||
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
@ -757,6 +842,7 @@ class LlamaPreTrainedModel(PreTrainedModel):
|
||||
_no_split_modules = ["LlamaDecoderLayer"]
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_cache_class = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
@ -862,6 +948,8 @@ class LlamaModel(LlamaPreTrainedModel):
|
||||
self.layers = nn.ModuleList(
|
||||
[LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||
)
|
||||
self._use_sdpa = config._attn_implementation == "sdpa"
|
||||
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
|
||||
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
@ -922,9 +1010,18 @@ class LlamaModel(LlamaPreTrainedModel):
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if getattr(self.config, "_flash_attn_2_enabled", False):
|
||||
if self._use_flash_attention_2:
|
||||
# 2d mask is passed through the layers
|
||||
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
||||
elif self._use_sdpa and not output_attentions:
|
||||
# output_attentions=True can not be supported when using SDPA, and we fall back on
|
||||
# the manual implementation that requires a 4D causal mask in all cases.
|
||||
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
inputs_embeds,
|
||||
past_key_values_length,
|
||||
)
|
||||
else:
|
||||
# 4d mask is passed through the layers
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
|
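The branch above is the heart of the mask handling: with flash attention the 2D padding mask is passed through (or dropped when there is no padding), while the SDPA path expands it to a 4D causal mask. A small sketch of the SDPA helper, with toy shapes assumed rather than taken from this diff:

    import torch
    from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask_for_sdpa

    batch_size, seq_length, past_key_values_length = 2, 5, 0
    attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long)  # 2D padding mask, no padding here
    inputs_embeds = torch.randn(batch_size, seq_length, 16)

    mask_4d = _prepare_4d_causal_attention_mask_for_sdpa(
        attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
    )
    # When nothing is masked and q_len > 1, the helper may return None so that SDPA's
    # is_causal fast path is used instead of materializing a [bsz, 1, q_len, kv_len] mask.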
@ -232,9 +232,8 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
|
||||
|
||||
self.multi_modal_projector = LlavaMultiModalProjector(config)
|
||||
self.vocab_size = config.vocab_size
|
||||
use_flash_attention_2 = getattr(config, "_flash_attn_2_enabled", False)
|
||||
self.language_model = AutoModelForCausalLM.from_config(
|
||||
config.text_config, use_flash_attention_2=use_flash_attention_2
|
||||
config.text_config, attn_implementation=config._attn_implementation
|
||||
)
|
||||
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
|
||||
self.post_init()
|
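The Llava change above shows how a composite model now forwards its attention backend to the sub-model it builds, instead of the removed `use_flash_attention_2` flag. A hedged sketch of the same call pattern outside of modeling code (the config used here is a stand-in for `config.text_config`):

    from transformers import AutoConfig, AutoModelForCausalLM

    text_config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")  # stand-in for config.text_config
    language_model = AutoModelForCausalLM.from_config(text_config, attn_implementation="sdpa")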
||||
|
@ -325,9 +325,8 @@ class M2M100EncoderLayer(nn.Module):
|
||||
def __init__(self, config: M2M100Config):
|
||||
super().__init__()
|
||||
self.embed_dim = config.d_model
|
||||
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
|
||||
|
||||
self.self_attn = M2M100_ATTENTION_CLASSES[attn_type](
|
||||
self.self_attn = M2M100_ATTENTION_CLASSES[config._attn_implementation](
|
||||
embed_dim=self.embed_dim,
|
||||
num_heads=config.encoder_attention_heads,
|
||||
dropout=config.attention_dropout,
|
||||
@ -392,7 +391,7 @@ class M2M100EncoderLayer(nn.Module):
|
||||
return outputs
|
||||
|
||||
|
||||
M2M100_ATTENTION_CLASSES = {"default": M2M100Attention}
|
||||
M2M100_ATTENTION_CLASSES = {"eager": M2M100Attention}
|
||||
|
||||
|
||||
# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->M2M100, MBART->M2M100
|
||||
@ -400,9 +399,8 @@ class M2M100DecoderLayer(nn.Module):
|
||||
def __init__(self, config: M2M100Config):
|
||||
super().__init__()
|
||||
self.embed_dim = config.d_model
|
||||
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
|
||||
|
||||
self.self_attn = M2M100_ATTENTION_CLASSES[attn_type](
|
||||
self.self_attn = M2M100_ATTENTION_CLASSES[config._attn_implementation](
|
||||
embed_dim=self.embed_dim,
|
||||
num_heads=config.decoder_attention_heads,
|
||||
dropout=config.attention_dropout,
|
||||
@ -415,7 +413,7 @@ class M2M100DecoderLayer(nn.Module):
|
||||
self.activation_dropout = config.activation_dropout
|
||||
|
||||
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||
self.encoder_attn = M2M100_ATTENTION_CLASSES[attn_type](
|
||||
self.encoder_attn = M2M100_ATTENTION_CLASSES[config._attn_implementation](
|
||||
self.embed_dim,
|
||||
config.decoder_attention_heads,
|
||||
dropout=config.attention_dropout,
|
||||
|
@ -272,9 +272,8 @@ class MarianEncoderLayer(nn.Module):
|
||||
def __init__(self, config: MarianConfig):
|
||||
super().__init__()
|
||||
self.embed_dim = config.d_model
|
||||
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
|
||||
|
||||
self.self_attn = MARIAN_ATTENTION_CLASSES[attn_type](
|
||||
self.self_attn = MARIAN_ATTENTION_CLASSES[config._attn_implementation](
|
||||
embed_dim=self.embed_dim,
|
||||
num_heads=config.encoder_attention_heads,
|
||||
dropout=config.attention_dropout,
|
||||
@ -339,7 +338,7 @@ class MarianEncoderLayer(nn.Module):
|
||||
return outputs
|
||||
|
||||
|
||||
MARIAN_ATTENTION_CLASSES = {"default": MarianAttention}
|
||||
MARIAN_ATTENTION_CLASSES = {"eager": MarianAttention}
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->Marian, BART->MARIAN
|
||||
@ -348,8 +347,7 @@ class MarianDecoderLayer(nn.Module):
|
||||
super().__init__()
|
||||
self.embed_dim = config.d_model
|
||||
|
||||
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
|
||||
self.self_attn = MARIAN_ATTENTION_CLASSES[attn_type](
|
||||
self.self_attn = MARIAN_ATTENTION_CLASSES[config._attn_implementation](
|
||||
embed_dim=self.embed_dim,
|
||||
num_heads=config.decoder_attention_heads,
|
||||
dropout=config.attention_dropout,
|
||||
@ -362,7 +360,7 @@ class MarianDecoderLayer(nn.Module):
|
||||
self.activation_dropout = config.activation_dropout
|
||||
|
||||
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||
self.encoder_attn = MARIAN_ATTENTION_CLASSES[attn_type](
|
||||
self.encoder_attn = MARIAN_ATTENTION_CLASSES[config._attn_implementation](
|
||||
self.embed_dim,
|
||||
config.decoder_attention_heads,
|
||||
dropout=config.attention_dropout,
|
||||
|
@ -501,7 +501,7 @@ class MBartFlashAttention2(MBartAttention):
|
||||
|
||||
|
||||
MBART_ATTENTION_CLASSES = {
|
||||
"default": MBartAttention,
|
||||
"eager": MBartAttention,
|
||||
"flash_attention_2": MBartFlashAttention2,
|
||||
}
|
||||
|
||||
@ -510,9 +510,8 @@ class MBartEncoderLayer(nn.Module):
|
||||
def __init__(self, config: MBartConfig):
|
||||
super().__init__()
|
||||
self.embed_dim = config.d_model
|
||||
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
|
||||
|
||||
self.self_attn = MBART_ATTENTION_CLASSES[attn_type](
|
||||
self.self_attn = MBART_ATTENTION_CLASSES[config._attn_implementation](
|
||||
embed_dim=self.embed_dim,
|
||||
num_heads=config.encoder_attention_heads,
|
||||
dropout=config.attention_dropout,
|
||||
@ -581,9 +580,8 @@ class MBartDecoderLayer(nn.Module):
|
||||
def __init__(self, config: MBartConfig):
|
||||
super().__init__()
|
||||
self.embed_dim = config.d_model
|
||||
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
|
||||
|
||||
self.self_attn = MBART_ATTENTION_CLASSES[attn_type](
|
||||
self.self_attn = MBART_ATTENTION_CLASSES[config._attn_implementation](
|
||||
embed_dim=self.embed_dim,
|
||||
num_heads=config.decoder_attention_heads,
|
||||
dropout=config.attention_dropout,
|
||||
@ -596,7 +594,7 @@ class MBartDecoderLayer(nn.Module):
|
||||
self.activation_dropout = config.activation_dropout
|
||||
|
||||
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||
self.encoder_attn = MBART_ATTENTION_CLASSES[attn_type](
|
||||
self.encoder_attn = MBART_ATTENTION_CLASSES[config._attn_implementation](
|
||||
self.embed_dim,
|
||||
config.decoder_attention_heads,
|
||||
dropout=config.attention_dropout,
|
||||
@ -935,6 +933,7 @@ class MBartEncoder(MBartPreTrainedModel):
|
||||
embed_dim,
|
||||
)
|
||||
self.layers = nn.ModuleList([MBartEncoderLayer(config) for _ in range(config.encoder_layers)])
|
||||
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
|
||||
self.layernorm_embedding = nn.LayerNorm(embed_dim)
|
||||
self.layer_norm = nn.LayerNorm(config.d_model)
|
||||
|
||||
@ -1023,7 +1022,7 @@ class MBartEncoder(MBartPreTrainedModel):
|
||||
# expand attention_mask
|
||||
if attention_mask is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
if getattr(self.config, "_flash_attn_2_enabled", False):
|
||||
if self._use_flash_attention_2:
|
||||
attention_mask = attention_mask if 0 in attention_mask else None
|
||||
else:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
@ -1112,6 +1111,7 @@ class MBartDecoder(MBartPreTrainedModel):
|
||||
config.d_model,
|
||||
)
|
||||
self.layers = nn.ModuleList([MBartDecoderLayer(config) for _ in range(config.decoder_layers)])
|
||||
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
|
||||
self.layernorm_embedding = nn.LayerNorm(config.d_model)
|
||||
self.layer_norm = nn.LayerNorm(config.d_model)
|
||||
|
||||
@ -1231,7 +1231,7 @@ class MBartDecoder(MBartPreTrainedModel):
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
||||
|
||||
if getattr(self.config, "_flash_attn_2_enabled", False):
|
||||
if self._use_flash_attention_2:
|
||||
# 2d mask is passed through the layers
|
||||
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
||||
else:
|
||||
@ -1242,7 +1242,7 @@ class MBartDecoder(MBartPreTrainedModel):
|
||||
|
||||
# expand encoder attention mask
|
||||
if encoder_hidden_states is not None and encoder_attention_mask is not None:
|
||||
if getattr(self.config, "_flash_attn_2_enabled", False):
|
||||
if self._use_flash_attention_2:
|
||||
encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
|
||||
else:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
|
@ -601,15 +601,19 @@ class MistralFlashAttention2(MistralAttention):
|
||||
)
|


MISTRAL_ATTENTION_CLASSES = {
    "eager": MistralAttention,
    "flash_attention_2": MistralFlashAttention2,
}
||||
|
||||
|
||||
class MistralDecoderLayer(nn.Module):
|
||||
def __init__(self, config: MistralConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.self_attn = (
|
||||
MistralAttention(config=config, layer_idx=layer_idx)
|
||||
if not getattr(config, "_flash_attn_2_enabled", False)
|
||||
else MistralFlashAttention2(config, layer_idx=layer_idx)
|
||||
)
|
||||
|
||||
self.self_attn = MISTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
|
||||
|
||||
self.mlp = MistralMLP(config)
|
||||
self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
@ -807,6 +811,7 @@ class MistralModel(MistralPreTrainedModel):
|
||||
self.layers = nn.ModuleList(
|
||||
[MistralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||
)
|
||||
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
|
||||
self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
@ -870,12 +875,7 @@ class MistralModel(MistralPreTrainedModel):
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if (
|
||||
attention_mask is not None
|
||||
and hasattr(self.config, "_flash_attn_2_enabled")
|
||||
and self.config._flash_attn_2_enabled
|
||||
and use_cache
|
||||
):
|
||||
if attention_mask is not None and self._use_flash_attention_2 and use_cache:
|
||||
is_padding_right = attention_mask[:, -1].sum().item() != batch_size
|
||||
if is_padding_right:
|
||||
raise ValueError(
|
||||
@ -884,7 +884,7 @@ class MistralModel(MistralPreTrainedModel):
|
||||
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
|
||||
)
|
||||
|
||||
if getattr(self.config, "_flash_attn_2_enabled", False):
|
||||
if self._use_flash_attention_2:
|
||||
# 2d mask is passed through the layers
|
||||
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
||||
else:
|
||||
|
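The Mistral path above keeps the guard that flash_attention_2 cannot generate correctly with right padding. A hedged sketch of the caller-side fix the error message asks for, reusing the tokenizer pattern from the tests later in this diff:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
    tokenizer.padding_side = "left"  # the last position of every row is then a real token
    tokenizer.pad_token = tokenizer.eos_token
    inputs = tokenizer(["hi", "a much longer prompt"], return_tensors="pt", padding=True)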
@ -491,15 +491,18 @@ class OptFlashAttention2(OPTAttention):
|
||||
)
|


OPT_ATTENTION_CLASSES = {
    "eager": OPTAttention,
    "flash_attention_2": OptFlashAttention2,
}
||||
|
||||
|
||||
class OPTDecoderLayer(nn.Module):
|
||||
def __init__(self, config: OPTConfig):
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
|
||||
if not getattr(config, "_flash_attn_2_enabled", False):
|
||||
self.self_attn = OPTAttention(config=config, is_decoder=True)
|
||||
else:
|
||||
self.self_attn = OptFlashAttention2(config=config, is_decoder=True)
|
||||
self.self_attn = OPT_ATTENTION_CLASSES[config._attn_implementation](config=config, is_decoder=True)
|
||||
|
||||
self.do_layer_norm_before = config.do_layer_norm_before
|
||||
self.dropout = config.dropout
|
||||
@ -732,6 +735,7 @@ class OPTDecoder(OPTPreTrainedModel):
|
||||
self.final_layer_norm = None
|
||||
|
||||
self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
# Initialize weights and apply final processing
|
||||
@ -830,7 +834,7 @@ class OPTDecoder(OPTPreTrainedModel):
|
||||
mask_seq_length = past_key_values_length + seq_length
|
||||
|
||||
# embed positions
|
||||
if getattr(self.config, "_flash_attn_2_enabled", False):
|
||||
if self._use_flash_attention_2:
|
||||
# 2d mask is passed through the layers
|
||||
causal_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
||||
attention_mask = (
|
||||
|
@ -267,7 +267,7 @@ class PegasusAttention(nn.Module):
|
||||
return attn_output, attn_weights_reshaped, past_key_value
|
||||
|
||||
|
||||
PEGASUS_ATTENTION_CLASSES = {"default": PegasusAttention}
|
||||
PEGASUS_ATTENTION_CLASSES = {"eager": PegasusAttention}
|
||||
|
||||
|
||||
# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Pegasus, MBART->PEGASUS
|
||||
@ -275,9 +275,8 @@ class PegasusEncoderLayer(nn.Module):
|
||||
def __init__(self, config: PegasusConfig):
|
||||
super().__init__()
|
||||
self.embed_dim = config.d_model
|
||||
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
|
||||
|
||||
self.self_attn = PEGASUS_ATTENTION_CLASSES[attn_type](
|
||||
self.self_attn = PEGASUS_ATTENTION_CLASSES[config._attn_implementation](
|
||||
embed_dim=self.embed_dim,
|
||||
num_heads=config.encoder_attention_heads,
|
||||
dropout=config.attention_dropout,
|
||||
@ -347,9 +346,8 @@ class PegasusDecoderLayer(nn.Module):
|
||||
def __init__(self, config: PegasusConfig):
|
||||
super().__init__()
|
||||
self.embed_dim = config.d_model
|
||||
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
|
||||
|
||||
self.self_attn = PEGASUS_ATTENTION_CLASSES[attn_type](
|
||||
self.self_attn = PEGASUS_ATTENTION_CLASSES[config._attn_implementation](
|
||||
embed_dim=self.embed_dim,
|
||||
num_heads=config.decoder_attention_heads,
|
||||
dropout=config.attention_dropout,
|
||||
@ -362,7 +360,7 @@ class PegasusDecoderLayer(nn.Module):
|
||||
self.activation_dropout = config.activation_dropout
|
||||
|
||||
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||
self.encoder_attn = PEGASUS_ATTENTION_CLASSES[attn_type](
|
||||
self.encoder_attn = PEGASUS_ATTENTION_CLASSES[config._attn_implementation](
|
||||
self.embed_dim,
|
||||
config.decoder_attention_heads,
|
||||
dropout=config.attention_dropout,
|
||||
|
@ -612,14 +612,16 @@ class PhiFlashAttention2(PhiAttention):
|
||||
)
|


PHI_ATTENTION_CLASSES = {
    "eager": PhiAttention,
    "flash_attention_2": PhiFlashAttention2,
}
||||
|
||||
|
||||
class PhiDecoderLayer(nn.Module):
|
||||
def __init__(self, config: PhiConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.self_attn = (
|
||||
PhiAttention(config=config, layer_idx=layer_idx)
|
||||
if not getattr(config, "_flash_attn_2_enabled", False)
|
||||
else PhiFlashAttention2(config=config, layer_idx=layer_idx)
|
||||
)
|
||||
self.self_attn = PHI_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
|
||||
self.mlp = PhiMLP(config)
|
||||
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.resid_dropout = nn.Dropout(config.resid_pdrop)
|
||||
@ -813,6 +815,7 @@ class PhiModel(PhiPreTrainedModel):
|
||||
[PhiDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||
)
|
||||
self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
# Initialize weights and apply final processing
|
||||
@ -876,7 +879,7 @@ class PhiModel(PhiPreTrainedModel):
|
||||
inputs_embeds = self.embed_dropout(inputs_embeds)
|
||||
|
||||
# Attention mask.
|
||||
if getattr(self.config, "_flash_attn_2_enabled", False):
|
||||
if self._use_flash_attention_2:
|
||||
# 2d mask is passed through the layers
|
||||
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
||||
else:
|
||||
|
@ -23,7 +23,12 @@ from torch import nn
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
|
||||
from ...modeling_attn_mask_utils import (
|
||||
_prepare_4d_attention_mask,
|
||||
_prepare_4d_attention_mask_for_sdpa,
|
||||
_prepare_4d_causal_attention_mask,
|
||||
_prepare_4d_causal_attention_mask_for_sdpa,
|
||||
)
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
@ -265,9 +270,8 @@ class PLBartEncoderLayer(nn.Module):
|
||||
def __init__(self, config: PLBartConfig):
|
||||
super().__init__()
|
||||
self.embed_dim = config.d_model
|
||||
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
|
||||
|
||||
self.self_attn = PLBART_ATTENTION_CLASSES[attn_type](
|
||||
self.self_attn = PLBART_ATTENTION_CLASSES[config._attn_implementation](
|
||||
embed_dim=self.embed_dim,
|
||||
num_heads=config.encoder_attention_heads,
|
||||
dropout=config.attention_dropout,
|
||||
@ -332,7 +336,8 @@ class PLBartEncoderLayer(nn.Module):
|
||||
return outputs
|
||||
|
||||
|
||||
PLBART_ATTENTION_CLASSES = {"default": PLBartAttention}
|
||||
# TODO: Implement attention with SDPA for PLBart.
|
||||
PLBART_ATTENTION_CLASSES = {"eager": PLBartAttention}
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->PLBart, BART->PLBART
|
||||
@ -341,8 +346,7 @@ class PLBartDecoderLayer(nn.Module):
|
||||
super().__init__()
|
||||
self.embed_dim = config.d_model
|
||||
|
||||
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
|
||||
self.self_attn = PLBART_ATTENTION_CLASSES[attn_type](
|
||||
self.self_attn = PLBART_ATTENTION_CLASSES[config._attn_implementation](
|
||||
embed_dim=self.embed_dim,
|
||||
num_heads=config.decoder_attention_heads,
|
||||
dropout=config.attention_dropout,
|
||||
@ -355,7 +359,7 @@ class PLBartDecoderLayer(nn.Module):
|
||||
self.activation_dropout = config.activation_dropout
|
||||
|
||||
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||
self.encoder_attn = PLBART_ATTENTION_CLASSES[attn_type](
|
||||
self.encoder_attn = PLBART_ATTENTION_CLASSES[config._attn_implementation](
|
||||
self.embed_dim,
|
||||
config.decoder_attention_heads,
|
||||
dropout=config.attention_dropout,
|
||||
@ -670,6 +674,8 @@ class PLBartEncoder(PLBartPreTrainedModel):
|
||||
embed_dim,
|
||||
)
|
||||
self.layers = nn.ModuleList([PLBartEncoderLayer(config) for _ in range(config.encoder_layers)])
|
||||
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
|
||||
self._use_sdpa = config._attn_implementation == "sdpa"
|
||||
self.layernorm_embedding = nn.LayerNorm(embed_dim)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
@ -757,8 +763,13 @@ class PLBartEncoder(PLBartPreTrainedModel):
|
||||
|
||||
# expand attention_mask
|
||||
if attention_mask is not None:
|
||||
if getattr(self.config, "_flash_attn_2_enabled", False):
|
||||
if self._use_flash_attention_2:
|
||||
attention_mask = attention_mask if 0 in attention_mask else None
|
||||
elif self._use_sdpa and head_mask is None and not output_attentions:
|
||||
# output_attentions=True & head_mask can not be supported when using SDPA, fall back to
|
||||
# the manual implementation that requires a 4D causal mask in all cases.
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
|
||||
else:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
|
||||
@ -846,6 +857,9 @@ class PLBartDecoder(PLBartPreTrainedModel):
|
||||
config.d_model,
|
||||
)
|
||||
self.layers = nn.ModuleList([PLBartDecoderLayer(config) for _ in range(config.decoder_layers)])
|
||||
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
|
||||
self._use_sdpa = config._attn_implementation == "sdpa"
|
||||
|
||||
self.layernorm_embedding = nn.LayerNorm(config.d_model)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
@ -964,9 +978,18 @@ class PLBartDecoder(PLBartPreTrainedModel):
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input) * self.embed_scale
|
||||
|
||||
if getattr(self.config, "_flash_attn_2_enabled", False):
|
||||
if self._use_flash_attention_2:
|
||||
# 2d mask is passed through the layers
|
||||
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
||||
elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None:
|
||||
# output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
|
||||
# the manual implementation that requires a 4D causal mask in all cases.
|
||||
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
||||
attention_mask,
|
||||
input_shape,
|
||||
inputs_embeds,
|
||||
past_key_values_length,
|
||||
)
|
||||
else:
|
||||
# 4d mask is passed through the layers
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
@ -975,8 +998,17 @@ class PLBartDecoder(PLBartPreTrainedModel):
|
||||
|
||||
# expand encoder attention mask
|
||||
if encoder_hidden_states is not None and encoder_attention_mask is not None:
|
||||
if getattr(self.config, "_flash_attn_2_enabled", False):
|
||||
if self._use_flash_attention_2:
|
||||
encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
|
||||
elif self._use_sdpa and cross_attn_head_mask is None and not output_attentions:
|
||||
# output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
|
||||
# the manual implementation that requires a 4D causal mask in all cases.
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
|
||||
encoder_attention_mask,
|
||||
inputs_embeds.dtype,
|
||||
tgt_len=input_shape[-1],
|
||||
)
|
||||
else:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
encoder_attention_mask = _prepare_4d_attention_mask(
|
||||
|
@ -30,7 +30,12 @@ from ...modeling_outputs import (
|
||||
Seq2SeqModelOutput,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
|
||||
from ...utils import (
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from .configuration_speech_to_text import Speech2TextConfig
|
||||
|
||||
|
||||
@ -326,7 +331,7 @@ class Speech2TextAttention(nn.Module):
|
||||
return attn_output, attn_weights_reshaped, past_key_value
|
||||
|
||||
|
||||
SPEECH_TO_TEXT_ATTENTION_CLASSES = {"default": Speech2TextAttention}
|
||||
SPEECH_TO_TEXT_ATTENTION_CLASSES = {"eager": Speech2TextAttention}
|
||||
|
||||
|
||||
# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Speech2Text, MBART->SPEECH_TO_TEXT
|
||||
@ -334,9 +339,8 @@ class Speech2TextEncoderLayer(nn.Module):
|
||||
def __init__(self, config: Speech2TextConfig):
|
||||
super().__init__()
|
||||
self.embed_dim = config.d_model
|
||||
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
|
||||
|
||||
self.self_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[attn_type](
|
||||
self.self_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[config._attn_implementation](
|
||||
embed_dim=self.embed_dim,
|
||||
num_heads=config.encoder_attention_heads,
|
||||
dropout=config.attention_dropout,
|
||||
@ -406,9 +410,8 @@ class Speech2TextDecoderLayer(nn.Module):
|
||||
def __init__(self, config: Speech2TextConfig):
|
||||
super().__init__()
|
||||
self.embed_dim = config.d_model
|
||||
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
|
||||
|
||||
self.self_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[attn_type](
|
||||
self.self_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[config._attn_implementation](
|
||||
embed_dim=self.embed_dim,
|
||||
num_heads=config.decoder_attention_heads,
|
||||
dropout=config.attention_dropout,
|
||||
@ -421,7 +424,7 @@ class Speech2TextDecoderLayer(nn.Module):
|
||||
self.activation_dropout = config.activation_dropout
|
||||
|
||||
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||
self.encoder_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[attn_type](
|
||||
self.encoder_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[config._attn_implementation](
|
||||
self.embed_dim,
|
||||
config.decoder_attention_heads,
|
||||
dropout=config.attention_dropout,
|
||||
|
@ -32,7 +32,12 @@ from ...modeling_outputs import (
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput
|
||||
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
|
||||
from ...utils import (
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from .configuration_time_series_transformer import TimeSeriesTransformerConfig
|
||||
|
||||
|
||||
@ -436,9 +441,8 @@ class TimeSeriesTransformerEncoderLayer(nn.Module):
|
||||
def __init__(self, config: TimeSeriesTransformerConfig):
|
||||
super().__init__()
|
||||
self.embed_dim = config.d_model
|
||||
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
|
||||
|
||||
self.self_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[attn_type](
|
||||
self.self_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[config._attn_implementation](
|
||||
embed_dim=self.embed_dim,
|
||||
num_heads=config.encoder_attention_heads,
|
||||
dropout=config.attention_dropout,
|
||||
@ -503,7 +507,10 @@ class TimeSeriesTransformerEncoderLayer(nn.Module):
|
||||
return outputs
|
||||
|
||||
|
||||
TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES = {"default": TimeSeriesTransformerAttention}
|
||||
# TODO: Implement attention with SDPA for TimeSeriesTransformer.
|
||||
TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES = {
|
||||
"eager": TimeSeriesTransformerAttention,
|
||||
}
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->TimeSeriesTransformer, with BART->TIME_SERIES_TRANSFORMER
|
||||
@ -512,8 +519,7 @@ class TimeSeriesTransformerDecoderLayer(nn.Module):
|
||||
super().__init__()
|
||||
self.embed_dim = config.d_model
|
||||
|
||||
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
|
||||
self.self_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[attn_type](
|
||||
self.self_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[config._attn_implementation](
|
||||
embed_dim=self.embed_dim,
|
||||
num_heads=config.decoder_attention_heads,
|
||||
dropout=config.attention_dropout,
|
||||
@ -526,7 +532,7 @@ class TimeSeriesTransformerDecoderLayer(nn.Module):
|
||||
self.activation_dropout = config.activation_dropout
|
||||
|
||||
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||
self.encoder_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[attn_type](
|
||||
self.encoder_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[config._attn_implementation](
|
||||
self.embed_dim,
|
||||
config.decoder_attention_heads,
|
||||
dropout=config.attention_dropout,
|
||||
|
@ -28,7 +28,7 @@ from torch.nn import CrossEntropyLoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...generation.logits_process import WhisperTimeStampLogitsProcessor
|
||||
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
||||
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
@ -690,9 +690,111 @@ class WhisperFlashAttention2(WhisperAttention):
|
||||
)
|
||||
|
||||
|
||||
class WhisperSdpaAttention(WhisperAttention):
|
||||
# Copied from transformers.models.bart.modeling_bart.BartSdpaAttention.forward with BART->whisper, Bart->Whisper
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
key_value_states: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
if output_attentions or layer_head_mask is not None:
|
||||
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
|
||||
logger.warning_once(
|
||||
"WhisperModel is using WhisperSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
|
||||
' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(
|
||||
hidden_states,
|
||||
key_value_states=key_value_states,
|
||||
past_key_value=past_key_value,
|
||||
attention_mask=attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||
# for the decoder
|
||||
is_cross_attention = key_value_states is not None
|
||||
|
||||
bsz, tgt_len, _ = hidden_states.size()
|
||||
|
||||
# get query proj
|
||||
query_states = self.q_proj(hidden_states)
|
||||
# get key, value proj
|
||||
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
|
||||
# is checking that the `sequence_length` of the `past_key_value` is the same as
|
||||
# the provided `key_value_states` to support prefix tuning
|
||||
if (
|
||||
is_cross_attention
|
||||
and past_key_value is not None
|
||||
and past_key_value[0].shape[2] == key_value_states.shape[1]
|
||||
):
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = past_key_value[0]
|
||||
value_states = past_key_value[1]
|
||||
elif is_cross_attention:
|
||||
# cross_attentions
|
||||
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
|
||||
elif past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
else:
|
||||
# self_attention
|
||||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_states, value_states)
|
||||
|
||||
query_states = self._shape(query_states, tgt_len, bsz)
|
||||
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=attention_mask,
|
||||
dropout_p=self.dropout if self.training else 0.0,
|
||||
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
|
||||
is_causal=self.is_causal and attention_mask is None and tgt_len > 1,
|
||||
)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
|
||||
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
|
||||
# partitioned across GPUs when using tensor-parallelism.
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
|
||||
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, None, past_key_value
|
||||
|
||||
|
||||
WHISPER_ATTENTION_CLASSES = {
|
||||
"default": WhisperAttention,
|
||||
"eager": WhisperAttention,
|
||||
"flash_attention_2": WhisperFlashAttention2,
|
||||
"sdpa": WhisperSdpaAttention,
|
||||
}
|
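Whisper gains the same dispatch table. A short sketch of opting into the SDPA path at load time (the checkpoint name is assumed, not taken from this diff):

    from transformers import WhisperForConditionalGeneration

    model = WhisperForConditionalGeneration.from_pretrained(
        "openai/whisper-tiny", attn_implementation="sdpa"
    )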
||||
|
||||
|
||||
@ -701,9 +803,8 @@ class WhisperEncoderLayer(nn.Module):
|
||||
def __init__(self, config: WhisperConfig):
|
||||
super().__init__()
|
||||
self.embed_dim = config.d_model
|
||||
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
|
||||
|
||||
self.self_attn = WHISPER_ATTENTION_CLASSES[attn_type](
|
||||
self.self_attn = WHISPER_ATTENTION_CLASSES[config._attn_implementation](
|
||||
embed_dim=self.embed_dim,
|
||||
num_heads=config.encoder_attention_heads,
|
||||
dropout=config.attention_dropout,
|
||||
@ -773,9 +874,8 @@ class WhisperDecoderLayer(nn.Module):
|
||||
def __init__(self, config: WhisperConfig):
|
||||
super().__init__()
|
||||
self.embed_dim = config.d_model
|
||||
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
|
||||
|
||||
self.self_attn = WHISPER_ATTENTION_CLASSES[attn_type](
|
||||
self.self_attn = WHISPER_ATTENTION_CLASSES[config._attn_implementation](
|
||||
embed_dim=self.embed_dim,
|
||||
num_heads=config.decoder_attention_heads,
|
||||
dropout=config.attention_dropout,
|
||||
@ -788,7 +888,7 @@ class WhisperDecoderLayer(nn.Module):
|
||||
self.activation_dropout = config.activation_dropout
|
||||
|
||||
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||
self.encoder_attn = WHISPER_ATTENTION_CLASSES[attn_type](
|
||||
self.encoder_attn = WHISPER_ATTENTION_CLASSES[config._attn_implementation](
|
||||
self.embed_dim,
|
||||
config.decoder_attention_heads,
|
||||
dropout=config.attention_dropout,
|
||||
@ -897,6 +997,7 @@ class WhisperPreTrainedModel(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["WhisperEncoderLayer", "WhisperDecoderLayer"]
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.init_std
|
||||
@ -1227,6 +1328,8 @@ class WhisperDecoder(WhisperPreTrainedModel):
|
||||
self.embed_positions = WhisperPositionalEmbedding(self.max_target_positions, config.d_model)
|
||||
|
||||
self.layers = nn.ModuleList([WhisperDecoderLayer(config) for _ in range(config.decoder_layers)])
|
||||
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
|
||||
self._use_sdpa = config._attn_implementation == "sdpa"
|
||||
|
||||
self.layer_norm = nn.LayerNorm(config.d_model)
|
||||
|
||||
@ -1336,9 +1439,14 @@ class WhisperDecoder(WhisperPreTrainedModel):
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if getattr(self.config, "_flash_attn_2_enabled", False):
|
||||
if self._use_flash_attention_2:
|
||||
# 2d mask is passed through the layers
|
||||
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
||||
elif self._use_sdpa and head_mask is None and not output_attentions:
|
||||
# output_attentions=True & head_mask can not be supported when using SDPA.
|
||||
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
||||
attention_mask, input_shape, inputs_embeds, past_key_values_length
|
||||
)
|
||||
else:
|
||||
# 4d mask is passed through the layers
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
|
@ -107,6 +107,7 @@ from .utils import (
|
||||
is_torch_fp16_available_on_device,
|
||||
is_torch_neuroncore_available,
|
||||
is_torch_npu_available,
|
||||
is_torch_sdpa_available,
|
||||
is_torch_tensorrt_fx_available,
|
||||
is_torch_tf32_available,
|
||||
is_torch_tpu_available,
|
||||
@ -440,6 +441,15 @@ def require_flash_attn(test_case):
|
||||
return unittest.skipUnless(is_flash_attn_2_available(), "test requires Flash Attention")(test_case)
|


def require_torch_sdpa(test_case):
    """
    Decorator marking a test that requires PyTorch's SDPA.

    These tests are skipped when requirements are not met (torch version).
    """
    return unittest.skipUnless(is_torch_sdpa_available(), "test requires PyTorch SDPA")(test_case)
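A hedged usage sketch of the new decorator (the test class and body are illustrative, not part of this diff):

    import unittest

    from transformers.testing_utils import require_torch_sdpa

    class MySdpaTests(unittest.TestCase):
        @require_torch_sdpa
        def test_runs_only_with_recent_torch(self):
            # Skipped automatically when is_torch_sdpa_available() returns False (torch < 2.1.1).
            ...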
||||
|
||||
|
||||
def require_peft(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires PEFT.
|
||||
|
@ -180,6 +180,7 @@ from .import_utils import (
|
||||
is_torch_mps_available,
|
||||
is_torch_neuroncore_available,
|
||||
is_torch_npu_available,
|
||||
is_torch_sdpa_available,
|
||||
is_torch_tensorrt_fx_available,
|
||||
is_torch_tf32_available,
|
||||
is_torch_tpu_available,
|
||||
|
@ -258,6 +258,19 @@ def get_torch_version():
|
||||
return _torch_version
|


def is_torch_sdpa_available():
    if not is_torch_available():
        return False
    elif _torch_version == "N/A":
        return False

    # NOTE: We require torch>=2.1 (and not torch>=2.0) to use SDPA in Transformers for two reasons:
    # - Allow the global use of the `scale` argument introduced in https://github.com/pytorch/pytorch/pull/95259
    # - Memory-efficient attention supports arbitrary attention_mask: https://github.com/pytorch/pytorch/pull/104310
    # NOTE: We require torch>=2.1.1 to avoid a numerical issue in SDPA with non-contiguous inputs: https://github.com/pytorch/pytorch/issues/112577
    return version.parse(_torch_version) >= version.parse("2.1.1")
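Downstream code can use this check to pick a backend defensively. A minimal sketch, under the assumption that the target model supports both implementations:

    from transformers import AutoModelForCausalLM
    from transformers.utils import is_torch_sdpa_available

    attn_implementation = "sdpa" if is_torch_sdpa_available() else "eager"
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf", attn_implementation=attn_implementation
    )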
||||
|
||||
|
||||
def is_torchvision_available():
|
||||
return _torchvision_available
|
||||
|
||||
|
@ -890,13 +890,11 @@ class BarkFineModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model_fa = model_class.from_pretrained(
|
||||
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
|
||||
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
|
||||
)
|
||||
model_fa.to(torch_device)
|
||||
|
||||
model = model_class.from_pretrained(
|
||||
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
|
||||
)
|
||||
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
|
||||
model.to(torch_device)
|
||||
|
||||
dummy_input = inputs_dict["input_ids"][:1]
|
||||
@ -949,12 +947,13 @@ class BarkFineModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model_fa = model_class.from_pretrained(
|
||||
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
|
||||
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
|
||||
)
|
||||
model_fa.to(torch_device)
|
||||
|
||||
model = model_class.from_pretrained(
|
||||
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
|
||||
tmpdirname,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
model.to(torch_device)
|
||||
|
||||
|
@ -319,13 +319,11 @@ class DistilBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model_fa = model_class.from_pretrained(
|
||||
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
|
||||
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
|
||||
)
|
||||
model_fa.to(torch_device)
|
||||
|
||||
model = model_class.from_pretrained(
|
||||
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
|
||||
)
|
||||
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
|
||||
model.to(torch_device)
|
||||
|
||||
logits = model(dummy_input, output_hidden_states=True).hidden_states[-1]
|
||||
@ -373,12 +371,13 @@ class DistilBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model_fa = model_class.from_pretrained(
|
||||
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
|
||||
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
|
||||
)
|
||||
model_fa.to(torch_device)
|
||||
|
||||
model = model_class.from_pretrained(
|
||||
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
|
||||
tmpdirname,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
model.to(torch_device)
|
||||
|
||||
|
@ -15,6 +15,7 @@
|
||||
""" Testing suite for the PyTorch Falcon model. """
|
||||
|
||||
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from parameterized import parameterized
|
||||
@ -26,7 +27,7 @@ from transformers import (
|
||||
is_torch_available,
|
||||
set_seed,
|
||||
)
|
||||
from transformers.testing_utils import require_bitsandbytes, require_torch, slow, torch_device
|
||||
from transformers.testing_utils import require_bitsandbytes, require_torch, require_torch_sdpa, slow, torch_device
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
@ -437,6 +438,76 @@ class FalconModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
|
||||
# The output should be different for long inputs
|
||||
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
|
||||
|
||||
@require_torch_sdpa
|
||||
@slow
|
||||
def test_eager_matches_sdpa_generate(self):
|
||||
max_new_tokens = 30
|
||||
|
||||
if len(self.all_generative_model_classes) == 0:
|
||||
self.skipTest(f"{self.__class__.__name__} tests a model that does not support generate: skipping this test")
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if not model_class._supports_sdpa:
|
||||
self.skipTest(f"{model_class.__name__} does not support SDPA")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
dummy_input = inputs_dict[model_class.main_input_name]
|
||||
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
|
||||
dummy_input = dummy_input.to(torch.float16)
|
||||
|
||||
# make sure that all models have enough positions for generation
|
||||
if hasattr(config, "max_position_embeddings"):
|
||||
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
|
||||
|
||||
model = model_class(config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
|
||||
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
|
||||
|
||||
model_sdpa = model_class.from_pretrained(
|
||||
tmpdirname,
|
||||
torch_dtype=torch.float16,
|
||||
low_cpu_mem_usage=True,
|
||||
).to(torch_device)
|
||||
|
||||
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
|
||||
|
||||
model_eager = model_class.from_pretrained(
|
||||
tmpdirname,
|
||||
torch_dtype=torch.float16,
|
||||
low_cpu_mem_usage=True,
|
||||
attn_implementation="eager",
|
||||
).to(torch_device)
|
||||
|
||||
self.assertTrue(model_eager.config._attn_implementation == "eager")
|
||||
|
||||
# NOTE: This check is disabled for Falcon as the non-SDPA/SDPA implementation is in the same class (legacy reason).
|
||||
# for name, submodule in model_eager.named_modules():
|
||||
# if "SdpaAttention" in submodule.__class__.__name__:
|
||||
# raise ValueError("The eager model should not have SDPA attention layers")
|
||||
|
||||
# has_sdpa = False
|
||||
# for name, submodule in model_sdpa.named_modules():
|
||||
# if "SdpaAttention" in submodule.__class__.__name__:
|
||||
# has_sdpa = True
|
||||
# break
|
||||
# if not has_sdpa:
|
||||
# raise ValueError("The SDPA model should have SDPA attention layers")
|
||||
|
||||
# Just test that a large cache works as expected
|
||||
res_eager = model_eager.generate(
|
||||
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False
|
||||
)
|
||||
|
||||
res_sdpa = model_sdpa.generate(
|
||||
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False
|
||||
)
|
||||
|
||||
self.assertTrue(torch.allclose(res_eager, res_sdpa))
|
||||
|
||||
|
||||
@require_torch
|
||||
class FalconLanguageGenerationTest(unittest.TestCase):
|
||||
|
@ -16,11 +16,14 @@
|
||||
|
||||
import unittest
|
||||
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import BitsAndBytesConfig, IdeficsConfig, is_torch_available, is_vision_available
|
||||
from transformers.testing_utils import (
|
||||
TestCasePlus,
|
||||
require_bitsandbytes,
|
||||
require_torch,
|
||||
require_torch_sdpa,
|
||||
require_vision,
|
||||
slow,
|
||||
torch_device,
|
||||
@ -309,6 +312,12 @@ class IdeficsModelTester:
|
||||
def prepare_pixel_values(self):
|
||||
return floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||
|
||||
@require_torch_sdpa
|
||||
@slow
|
||||
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
|
||||
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
|
||||
self.skipTest("Idefics has a hard requirement on SDPA, skipping this test")
|
||||
|
||||
|
||||
@unittest.skipIf(not is_torch_greater_or_equal_than_2_0, reason="pytorch 2.0 or higher is required")
|
||||
@require_torch
|
||||
@ -557,6 +566,12 @@ class IdeficsModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
|
||||
model = IdeficsModel.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@require_torch_sdpa
|
||||
@slow
|
||||
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
|
||||
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
|
||||
self.skipTest("Idefics has a hard requirement on SDPA, skipping this test")
|
||||
|
||||
|
||||
@unittest.skipIf(not is_torch_greater_or_equal_than_2_0, reason="pytorch 2.0 or higher is required")
|
||||
@require_torch
|
||||
|
@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
""" Testing suite for the PyTorch LLaMA model. """
|
||||
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
@ -26,6 +27,7 @@ from transformers.testing_utils import (
|
||||
require_torch,
|
||||
require_torch_accelerator,
|
||||
require_torch_gpu,
|
||||
require_torch_sdpa,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
@ -411,7 +413,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
output_native = tokenizer.batch_decode(output_native)
|
||||
|
||||
model = LlamaForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-2-7b-hf", load_in_4bit=True, device_map={"": 0}, use_flash_attention_2=True
|
||||
"meta-llama/Llama-2-7b-hf", load_in_4bit=True, device_map={"": 0}, attn_implementation="flash_attention_2"
|
||||
)
|
||||
|
||||
output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||
@ -419,6 +421,85 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
|
||||
self.assertListEqual(output_native, output_fa_2)
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@slow
|
||||
def test_use_flash_attention_2_true(self):
|
||||
"""
|
||||
NOTE: this is the only test testing that the legacy `use_flash_attention_2=True` argument still works as intended.
|
||||
"""
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
for model_class in self.all_model_classes:
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model = model_class(config)
|
||||
model.save_pretrained(tmp_dir)
|
||||
|
||||
new_model = LlamaForCausalLM.from_pretrained(
|
||||
tmp_dir, use_flash_attention_2=True, torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
|
||||
self.assertTrue(new_model.config._attn_implementation == "flash_attention_2")
|
||||
|
||||
has_flash = False
|
||||
for name, submodule in new_model.named_modules():
|
||||
if "FlashAttention" in submodule.__class__.__name__:
|
||||
has_flash = True
|
||||
break
|
||||
if not has_flash:
|
||||
raise ValueError("The flash model should have flash attention layers")
|
||||
|
||||
@require_torch_sdpa
|
||||
@slow
|
||||
def test_eager_matches_sdpa_generate(self):
|
||||
"""
|
||||
Overwriting the common test as the test is flaky on tiny models
|
||||
"""
|
||||
max_new_tokens = 30
|
||||
|
||||
tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
|
||||
|
||||
model_sdpa = LlamaForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-2-7b-hf",
|
||||
torch_dtype=torch.float16,
|
||||
low_cpu_mem_usage=True,
|
||||
).to(torch_device)
|
||||
|
||||
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
|
||||
|
||||
model_eager = LlamaForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-2-7b-hf",
|
||||
torch_dtype=torch.float16,
|
||||
low_cpu_mem_usage=True,
|
||||
attn_implementation="eager",
|
||||
).to(torch_device)
|
||||
|
||||
self.assertTrue(model_eager.config._attn_implementation == "eager")
|
||||
|
||||
for name, submodule in model_eager.named_modules():
|
||||
if "SdpaAttention" in submodule.__class__.__name__:
|
||||
raise ValueError("The eager model should not have SDPA attention layers")
|
||||
|
||||
has_sdpa = False
|
||||
for name, submodule in model_sdpa.named_modules():
|
||||
if "SdpaAttention" in submodule.__class__.__name__:
|
||||
has_sdpa = True
|
||||
break
|
||||
if not has_sdpa:
|
||||
raise ValueError("The SDPA model should have SDPA attention layers")
|
||||
|
||||
texts = ["hi", "Hello this is a very long sentence my friend", "Today I am in Paris and"]
|
||||
|
||||
for padding_side in ["left", "right"]:
|
||||
tokenizer.padding_side = padding_side
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
inputs = tokenizer(texts, return_tensors="pt", padding=True).to(torch_device)
|
||||
|
||||
res_eager = model_eager.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
|
||||
|
||||
res_sdpa = model_sdpa.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
|
||||
self.assertTrue(torch.allclose(res_eager, res_sdpa))
|
||||
|
||||
|
||||
@require_torch
|
||||
class LlamaIntegrationTest(unittest.TestCase):
|
||||
|
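The new Llama test above also documents the default selection behavior: when no `attn_implementation` is passed and the requirements are met, `from_pretrained` picks SDPA and records the choice on the config. A minimal sketch of that check, using the private `_attn_implementation` attribute exactly as the tests do:

```python
from transformers import AutoModelForCausalLM

# With torch>=2.1 and a supported architecture, SDPA is picked by default.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
print(model.config._attn_implementation)  # "sdpa" on a supported setup, "eager" otherwise

# Force the manual implementation for comparison.
model_eager = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", attn_implementation="eager"
)
print(model_eager.config._attn_implementation)  # "eager"
```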
@@ -387,9 +387,9 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi

with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=False, low_cpu_mem_usage=True
).to(torch_device)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
torch_device
)

dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
@@ -397,7 +397,10 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False)

model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True, low_cpu_mem_usage=True
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)

with self.assertRaises(ValueError):
@@ -437,7 +440,7 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
use_flash_attention_2=True,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)

@@ -507,7 +510,7 @@ class MistralIntegrationTest(unittest.TestCase):
"mistralai/Mistral-7B-v0.1",
device_map="auto",
load_in_4bit=True,
use_flash_attention_2=True,
attn_implementation="flash_attention_2",
)
input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)

@@ -389,7 +389,7 @@ class PhiModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
output_native = tokenizer.batch_decode(output_native)

model = PhiForCausalLM.from_pretrained(
"susnato/phi-1_5_dev", load_in_4bit=True, device_map={"": 0}, use_flash_attention_2=True
"susnato/phi-1_5_dev", load_in_4bit=True, device_map={"": 0}, attn_implementation="flash_attention_2"
)

output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
@@ -891,12 +891,13 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)

model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
tmpdirname,
torch_dtype=torch.bfloat16,
)
model.to(torch_device)

@@ -936,11 +937,11 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True
tmpdirname, torch_dtype=torch.float16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)

model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=False)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16)
model.to(torch_device)

dummy_input = inputs_dict[model.main_input_name][:1]
@@ -981,6 +982,7 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi

configs_no_init = _config_zero_init(config) # To be sure we have no Nan
configs_no_init.torchscript = True
configs_no_init._attn_implementation = "eager"
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
model.to(torch_device)
@@ -2337,13 +2339,20 @@ class WhisperEncoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.
with torch.no_grad():
outputs = model(**inputs)[0]

input_ids = inputs["input_features"]
del inputs["input_features"]

encoder = model.encoder

encoder_inputs = {"input_features": inputs["input_features"]}
del inputs["input_features"]

if "head_mask" in inputs:
encoder_inputs["head_mask"] = inputs["head_mask"]
if "attention_mask" in inputs:
encoder_inputs["attention_mask"] = inputs["attention_mask"]
if "output_attentions" in inputs:
encoder_inputs["output_attentions"] = inputs["output_attentions"]

with torch.no_grad():
inputs["encoder_outputs"] = encoder(input_ids)
inputs["encoder_outputs"] = encoder(**encoder_inputs)
outputs_embeds = model(**inputs)[0]

self.assertTrue((outputs_embeds == outputs).all())
@@ -198,7 +198,14 @@ class ConfigTestUtils(unittest.TestCase):
missing_keys = [key for key in base_config.__dict__ if key not in config_common_kwargs]
# If this part of the test fails, you have arguments to add in config_common_kwargs above.
self.assertListEqual(
missing_keys, ["is_encoder_decoder", "_name_or_path", "_commit_hash", "transformers_version"]
missing_keys,
[
"is_encoder_decoder",
"_name_or_path",
"_commit_hash",
"_attn_implementation_internal",
"transformers_version",
],
)
keys_with_defaults = [key for key, value in config_common_kwargs.items() if value == getattr(base_config, key)]
if len(keys_with_defaults) > 0:
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import copy
import gc
@@ -28,6 +27,7 @@ from collections import defaultdict
from typing import Dict, List, Tuple

import numpy as np
from parameterized import parameterized
from pytest import mark

import transformers
@@ -71,6 +71,7 @@ from transformers.testing_utils import (
require_torch,
require_torch_gpu,
require_torch_multi_gpu,
require_torch_sdpa,
slow,
torch_device,
)
@@ -776,102 +777,120 @@ class ModelTesterMixin:
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
configs_no_init.torchscript = True
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
for attn_implementation in ["eager", "sdpa"]:
if attn_implementation == "sdpa" and not model_class._supports_sdpa:
continue

main_input_name = model_class.main_input_name
configs_no_init._attn_implementation = attn_implementation
model = model_class(config=configs_no_init)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)

try:
if model.config.is_encoder_decoder:
model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
main_input = inputs[main_input_name]
attention_mask = inputs["attention_mask"]
decoder_input_ids = inputs["decoder_input_ids"]
decoder_attention_mask = inputs["decoder_attention_mask"]
model(main_input, attention_mask, decoder_input_ids, decoder_attention_mask)
traced_model = torch.jit.trace(
model, (main_input, attention_mask, decoder_input_ids, decoder_attention_mask)
)
elif "bbox" in inputs and "image" in inputs: # LayoutLMv2 requires additional inputs
input_ids = inputs["input_ids"]
bbox = inputs["bbox"]
image = inputs["image"].tensor
model(input_ids, bbox, image)
traced_model = torch.jit.trace(
model, (input_ids, bbox, image), check_trace=False
) # when traced model is checked, an error is produced due to name mangling
elif "bbox" in inputs: # Bros requires additional inputs (bbox)
input_ids = inputs["input_ids"]
bbox = inputs["bbox"]
model(input_ids, bbox)
traced_model = torch.jit.trace(
model, (input_ids, bbox), check_trace=False
) # when traced model is checked, an error is produced due to name mangling
else:
main_input = inputs[main_input_name]
model(main_input)
traced_model = torch.jit.trace(model, main_input)
except RuntimeError:
self.fail("Couldn't trace module.")

with tempfile.TemporaryDirectory() as tmp_dir_name:
pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")
main_input_name = model_class.main_input_name

try:
torch.jit.save(traced_model, pt_file_name)
except Exception:
self.fail("Couldn't save module.")
if model.config.is_encoder_decoder:
model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
main_input = inputs[main_input_name]
attention_mask = inputs["attention_mask"]
decoder_input_ids = inputs["decoder_input_ids"]
decoder_attention_mask = inputs["decoder_attention_mask"]
model(main_input, attention_mask, decoder_input_ids, decoder_attention_mask)
traced_model = torch.jit.trace(
model, (main_input, attention_mask, decoder_input_ids, decoder_attention_mask)
)
elif "bbox" in inputs and "image" in inputs: # LayoutLMv2 requires additional inputs
input_ids = inputs["input_ids"]
bbox = inputs["bbox"]
image = inputs["image"].tensor
model(input_ids, bbox, image)
traced_model = torch.jit.trace(
model, (input_ids, bbox, image), check_trace=False
) # when traced model is checked, an error is produced due to name mangling
elif "bbox" in inputs: # Bros requires additional inputs (bbox)
input_ids = inputs["input_ids"]
bbox = inputs["bbox"]
model(input_ids, bbox)
traced_model = torch.jit.trace(
model, (input_ids, bbox), check_trace=False
) # when traced model is checked, an error is produced due to name mangling
else:
main_input = inputs[main_input_name]

try:
loaded_model = torch.jit.load(pt_file_name)
except Exception:
self.fail("Couldn't load module.")
if model.config._attn_implementation == "sdpa":
trace_input = {main_input_name: main_input}

model.to(torch_device)
model.eval()
if "attention_mask" in inputs:
trace_input["attention_mask"] = inputs["attention_mask"]
else:
self.skipTest("testing SDPA without attention_mask is not supported")

loaded_model.to(torch_device)
loaded_model.eval()
model(main_input, attention_mask=inputs["attention_mask"])
# example_kwarg_inputs was introduced in torch==2.0, but it is fine here since SDPA has a requirement on torch>=2.1.
traced_model = torch.jit.trace(model, example_kwarg_inputs=trace_input)
else:
model(main_input)
traced_model = torch.jit.trace(model, (main_input,))
except RuntimeError:
self.fail("Couldn't trace module.")

model_state_dict = model.state_dict()
loaded_model_state_dict = loaded_model.state_dict()
with tempfile.TemporaryDirectory() as tmp_dir_name:
pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")

non_persistent_buffers = {}
for key in loaded_model_state_dict.keys():
if key not in model_state_dict.keys():
non_persistent_buffers[key] = loaded_model_state_dict[key]
try:
torch.jit.save(traced_model, pt_file_name)
except Exception:
self.fail("Couldn't save module.")

loaded_model_state_dict = {
key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
}
try:
loaded_model = torch.jit.load(pt_file_name)
except Exception:
self.fail("Couldn't load module.")

self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
model.to(torch_device)
model.eval()

model_buffers = list(model.buffers())
for non_persistent_buffer in non_persistent_buffers.values():
found_buffer = False
for i, model_buffer in enumerate(model_buffers):
if torch.equal(non_persistent_buffer, model_buffer):
found_buffer = True
break
loaded_model.to(torch_device)
loaded_model.eval()

self.assertTrue(found_buffer)
model_buffers.pop(i)
model_state_dict = model.state_dict()
loaded_model_state_dict = loaded_model.state_dict()

models_equal = True
for layer_name, p1 in model_state_dict.items():
if layer_name in loaded_model_state_dict:
p2 = loaded_model_state_dict[layer_name]
if p1.data.ne(p2.data).sum() > 0:
models_equal = False
non_persistent_buffers = {}
for key in loaded_model_state_dict.keys():
if key not in model_state_dict.keys():
non_persistent_buffers[key] = loaded_model_state_dict[key]

self.assertTrue(models_equal)
loaded_model_state_dict = {
key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
}

# Avoid memory leak. Without this, each call increases RAM usage by ~20MB.
# (Even with this call, there is still a memory leak of ~0.04MB)
self.clear_torch_jit_class_registry()
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))

model_buffers = list(model.buffers())
for non_persistent_buffer in non_persistent_buffers.values():
found_buffer = False
for i, model_buffer in enumerate(model_buffers):
if torch.equal(non_persistent_buffer, model_buffer):
found_buffer = True
break

self.assertTrue(found_buffer)
model_buffers.pop(i)

models_equal = True
for layer_name, p1 in model_state_dict.items():
if layer_name in loaded_model_state_dict:
p2 = loaded_model_state_dict[layer_name]
if p1.data.ne(p2.data).sum() > 0:
models_equal = False

self.assertTrue(models_equal)

# Avoid memory leak. Without this, each call increases RAM usage by ~20MB.
# (Even with this call, there is still a memory leak of ~0.04MB)
self.clear_torch_jit_class_registry()

def test_torch_fx(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
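The reworked torchscript test above traces SDPA models by passing the attention mask as a keyword input through `example_kwarg_inputs`, a `torch.jit.trace` parameter available since torch 2.0. A minimal, self-contained sketch of that tracing pattern, using a hypothetical toy module (not one of the test's models) whose forward calls `F.scaled_dot_product_attention`:

```python
import torch
import torch.nn.functional as F


class TinySdpaBlock(torch.nn.Module):
    # Hypothetical stand-in for a model layer that routes attention through SDPA.
    def forward(self, hidden_states, attention_mask=None):
        return F.scaled_dot_product_attention(
            hidden_states, hidden_states, hidden_states, attn_mask=attention_mask
        )


block = TinySdpaBlock().eval()
hidden_states = torch.rand(1, 2, 4, 8)          # (batch, heads, seq, head_dim)
attention_mask = torch.zeros(1, 1, 4, 4)        # additive float mask, 0.0 means "attend"

# Keyword inputs let the mask be traced under its real argument name (torch>=2.0).
traced = torch.jit.trace(
    block, example_kwarg_inputs={"hidden_states": hidden_states, "attention_mask": attention_mask}
)
print(traced(hidden_states, attention_mask).shape)  # torch.Size([1, 2, 4, 8])
```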
@@ -2832,8 +2851,6 @@ class ModelTesterMixin:
@mark.flash_attn_test
@slow
def test_flash_attn_2_conversion(self):
import torch

config, _ = self.model_tester.prepare_config_and_inputs_for_common()

for model_class in self.all_model_classes:
@@ -2845,7 +2862,7 @@ class ModelTesterMixin:
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True
tmpdirname, torch_dtype=torch.float16, attn_implementation="flash_attention_2"
).to(torch_device)

for _, module in model.named_modules():
@@ -2859,8 +2876,6 @@ class ModelTesterMixin:
@mark.flash_attn_test
@slow
def test_flash_attn_2_inference(self):
import torch

for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
@@ -2871,12 +2886,12 @@ class ModelTesterMixin:
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)

model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model.to(torch_device)

@@ -2956,8 +2971,6 @@ class ModelTesterMixin:
@mark.flash_attn_test
@slow
def test_flash_attn_2_inference_padding_right(self):
import torch

for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
@@ -2968,12 +2981,12 @@ class ModelTesterMixin:
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)

model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model.to(torch_device)

@@ -3049,8 +3062,6 @@ class ModelTesterMixin:
@mark.flash_attn_test
@slow
def test_flash_attn_2_generate_left_padding(self):
import torch

for model_class in self.all_generative_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
@@ -3060,9 +3071,9 @@ class ModelTesterMixin:

with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=False, low_cpu_mem_usage=True
).to(torch_device)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
torch_device
)

dummy_input = inputs_dict[model.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
@@ -3078,7 +3089,10 @@ class ModelTesterMixin:
)

model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True, low_cpu_mem_usage=True
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)

out_fa = model.generate(
@@ -3092,8 +3106,6 @@ class ModelTesterMixin:
@mark.flash_attn_test
@slow
def test_flash_attn_2_generate_padding_right(self):
import torch

for model_class in self.all_generative_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
@@ -3103,9 +3115,9 @@ class ModelTesterMixin:

with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=False, low_cpu_mem_usage=True
).to(torch_device)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
torch_device
)

dummy_input = inputs_dict[model.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
@@ -3121,7 +3133,10 @@ class ModelTesterMixin:
)

model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True, low_cpu_mem_usage=True
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)

out_fa = model.generate(
@@ -3130,13 +3145,330 @@ class ModelTesterMixin:

self.assertTrue(torch.allclose(out, out_fa))

@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
@slow
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
if not self.all_model_classes[0]._supports_sdpa:
self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")

if torch_device == "cpu" and torch_dtype == "float16":
self.skipTest("float16 not supported on cpu")

# Not sure whether it's fine to put torch.XXX in a decorator if torch is not available so hacking it here instead.
if torch_dtype == "float16":
torch_dtype = torch.float16
elif torch_dtype == "bfloat16":
torch_dtype = torch.bfloat16
elif torch_dtype == "float32":
torch_dtype = torch.float32

atols = {
("cpu", False, torch.float32): 1e-6,
("cpu", False, torch.bfloat16): 1e-2,
("cpu", True, torch.float32): 1e-6,
("cpu", True, torch.bfloat16): 1e-2,
("cuda", False, torch.float32): 1e-6,
("cuda", False, torch.bfloat16): 1e-2,
("cuda", False, torch.float16): 1e-3,
("cuda", True, torch.float32): 1e-6,
("cuda", True, torch.bfloat16): 1e-2,
("cuda", True, torch.float16): 5e-3,
}
rtols = {
("cpu", False, torch.float32): 1e-4,
("cpu", False, torch.bfloat16): 1e-2,
("cpu", True, torch.float32): 1e-4,
("cpu", True, torch.bfloat16): 1e-2,
("cuda", False, torch.float32): 1e-4,
("cuda", False, torch.bfloat16): 1e-2,
("cuda", False, torch.float16): 1e-3,
("cuda", True, torch.float32): 1e-4,
("cuda", True, torch.bfloat16): 3e-2,
("cuda", True, torch.float16): 5e-3,
}

def get_mean_reldiff(failcase, x, ref, atol, rtol):
return f"{failcase}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}"

for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)

is_encoder_decoder = model.config.is_encoder_decoder

with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype)
model_sdpa = model_sdpa.eval().to(torch_device)

self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")

model_eager = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch_dtype,
attn_implementation="eager",
)
model_eager = model_eager.eval().to(torch_device)

self.assertTrue(model_eager.config._attn_implementation == "eager")

for name, submodule in model_eager.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
raise ValueError("The eager model should not have SDPA attention layers")

has_sdpa = False
for name, submodule in model_sdpa.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
has_sdpa = True
break
if not has_sdpa and model_sdpa.config.model_type != "falcon":
raise ValueError("The SDPA model should have SDPA attention layers")

# We use these for loops instead of parameterized.expand just in the interest of avoiding loading/saving the model 8 times,
# but it would be nicer to have an efficient way to use parameterized.expand
fail_cases = []
for padding_side in ["left", "right"]:
for use_mask in [False, True]:
for batch_size in [1, 5]:
dummy_input = inputs_dict[model.main_input_name]

if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]:
dummy_input = dummy_input.to(torch_dtype)

dummy_input = dummy_input[:batch_size]
if dummy_input.shape[0] != batch_size:
if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]:
extension = torch.rand(
batch_size - dummy_input.shape[0],
*dummy_input.shape[1:],
dtype=torch_dtype,
device=torch_device,
)
dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device)
else:
extension = torch.randint(
high=5,
size=(batch_size - dummy_input.shape[0], *dummy_input.shape[1:]),
dtype=dummy_input.dtype,
device=torch_device,
)
dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device)

if not use_mask:
dummy_attention_mask = None
else:
dummy_attention_mask = inputs_dict.get("attention_mask", None)
if dummy_attention_mask is None:
if is_encoder_decoder:
seqlen = inputs_dict.get("decoder_input_ids", dummy_input).shape[-1]
else:
seqlen = dummy_input.shape[-1]
dummy_attention_mask = (
torch.ones(batch_size, seqlen).to(torch.int64).to(torch_device)
)

dummy_attention_mask = dummy_attention_mask[:batch_size]
if dummy_attention_mask.shape[0] != batch_size:
extension = torch.ones(
batch_size - dummy_attention_mask.shape[0],
*dummy_attention_mask.shape[1:],
dtype=dummy_attention_mask.dtype,
device=torch_device,
)
dummy_attention_mask = torch.cat((dummy_attention_mask, extension), dim=0)
dummy_attention_mask = dummy_attention_mask.to(torch_device)

dummy_attention_mask[:] = 1
if padding_side == "left":
dummy_attention_mask[-1, :-1] = 1
dummy_attention_mask[-1, -4:] = 0
elif padding_side == "right":
dummy_attention_mask[-1, 1:] = 1
dummy_attention_mask[-1, :3] = 0

for enable_kernels in [False, True]:
failcase = f"padding_side={padding_side}, use_mask={use_mask}, batch_size={batch_size}, enable_kernels={enable_kernels}"
if is_encoder_decoder:
decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:batch_size]
if decoder_input_ids.shape[0] != batch_size:
extension = torch.ones(
batch_size - decoder_input_ids.shape[0],
*decoder_input_ids.shape[1:],
dtype=decoder_input_ids.dtype,
device=torch_device,
)
decoder_input_ids = torch.cat((decoder_input_ids, extension), dim=0)
decoder_input_ids = decoder_input_ids.to(torch_device)

# TODO: never an `attention_mask` arg here?
other_inputs = {
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": dummy_attention_mask,
"output_hidden_states": True,
}
else:
other_inputs = {
"output_hidden_states": True,
}

# Otherwise fails for e.g. WhisperEncoderModel
if "attention_mask" in inspect.signature(model_eager.forward).parameters:
other_inputs["attention_mask"] = dummy_attention_mask

# TODO: test gradients as well (& for FA2 as well!)
with torch.no_grad():
with torch.backends.cuda.sdp_kernel(
enable_flash=enable_kernels,
enable_math=True,
enable_mem_efficient=enable_kernels,
):
outputs_eager = model_eager(dummy_input, **other_inputs)
outputs_sdpa = model_sdpa(dummy_input, **other_inputs)

logits_eager = (
outputs_eager.hidden_states[-1]
if not is_encoder_decoder
else outputs_eager.decoder_hidden_states[-1]
)
logits_sdpa = (
outputs_sdpa.hidden_states[-1]
if not is_encoder_decoder
else outputs_sdpa.decoder_hidden_states[-1]
)

if torch_device in ["cpu", "cuda"]:
atol = atols[torch_device, enable_kernels, torch_dtype]
rtol = rtols[torch_device, enable_kernels, torch_dtype]
else:
atol = 1e-7
rtol = 1e-4

# Masked tokens output slightly deviates - we don't mind that.
if use_mask:
if padding_side == "left":
sub_sdpa = logits_sdpa[:-1]
sub_eager = logits_eager[:-1]
if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
fail_cases.append(
get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
)

sub_sdpa = logits_sdpa[-1, :-4]
sub_eager = logits_eager[-1, :-4]
if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
fail_cases.append(
get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
)

# Testing the padding tokens is not really meaningful but anyway
# sub_sdpa = logits_sdpa[-1, -4:]
# sub_eager = logits_eager[-1, -4:]
# if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
# fail_cases.append(get_mean_reldiff(failcase, sub_sdpa, sub_eager, 4e-2, 4e-2))
elif padding_side == "right":
sub_sdpa = logits_sdpa[:-1]
sub_eager = logits_eager[:-1]
if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
fail_cases.append(
get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
)

sub_sdpa = logits_sdpa[-1, 3:]
sub_eager = logits_eager[-1, 3:]
if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
fail_cases.append(
get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
)

# Testing the padding tokens is not really meaningful but anyway
# sub_sdpa = logits_sdpa[-1, :3]
# sub_eager = logits_eager[-1, :3]
# if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
# fail_cases.append(get_mean_reldiff(failcase, sub_sdpa, sub_eager, 4e-2, 4e-2))

else:
if not torch.allclose(logits_sdpa, logits_eager, atol=atol, rtol=rtol):
fail_cases.append(
get_mean_reldiff(failcase, logits_sdpa, logits_eager, atol, rtol)
)

self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases))
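The test above compares eager and SDPA outputs under each SDPA backend by toggling them with `torch.backends.cuda.sdp_kernel`. A minimal standalone sketch of that backend-selection mechanism (CUDA and fp16 assumed; note that `sdp_kernel` is the torch 2.1-era context manager, and newer PyTorch releases expose `torch.nn.attention.sdpa_kernel` instead):

```python
import torch
import torch.nn.functional as F

q = torch.rand(1, 8, 16, 64, device="cuda", dtype=torch.float16)
k = torch.rand(1, 8, 16, 64, device="cuda", dtype=torch.float16)
v = torch.rand(1, 8, 16, 64, device="cuda", dtype=torch.float16)

# Math-only backend, closest to the eager reference implementation.
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
    out_math = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Fused kernels allowed, as in the enable_kernels=True branch of the test.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
    out_fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print((out_math - out_fused).abs().max())  # should stay within the atol/rtol tables above
```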
@require_torch_sdpa
@slow
def test_eager_matches_sdpa_generate(self):
max_new_tokens = 30

if len(self.all_generative_model_classes) == 0:
self.skipTest(f"{self.__class__.__name__} tests a model that does not support generate: skipping this test")

for model_class in self.all_generative_model_classes:
if not model_class._supports_sdpa:
self.skipTest(f"{model_class.__name__} does not support SDPA")

config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

dummy_input = inputs_dict[model_class.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)

# make sure that all models have enough positions for generation
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1

model = model_class(config)

with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)

dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))

model_sdpa = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
).to(torch_device)

self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")

model_eager = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
attn_implementation="eager",
).to(torch_device)

self.assertTrue(model_eager.config._attn_implementation == "eager")

for name, submodule in model_eager.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
raise ValueError("The eager model should not have SDPA attention layers")

has_sdpa = False
for name, submodule in model_sdpa.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
has_sdpa = True
break
if not has_sdpa:
raise ValueError("The SDPA model should have SDPA attention layers")

# Just test that a large cache works as expected
res_eager = model_eager.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False
)

res_sdpa = model_sdpa.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False
)

self.assertTrue(torch.allclose(res_eager, res_sdpa))

@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
def test_flash_attn_2_generate_use_cache(self):
import torch

max_new_tokens = 30

for model_class in self.all_generative_model_classes:
@@ -3163,7 +3495,7 @@ class ModelTesterMixin:
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
use_flash_attention_2=True,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)

@@ -3182,8 +3514,6 @@ class ModelTesterMixin:
@mark.flash_attn_test
@slow
def test_flash_attn_2_fp32_ln(self):
import torch

for model_class in self.all_generative_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
@@ -3204,7 +3534,7 @@ class ModelTesterMixin:
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
use_flash_attention_2=True,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
load_in_4bit=True,
)
@@ -3282,8 +3612,6 @@ class ModelTesterMixin:
@mark.flash_attn_test
@slow
def test_flash_attn_2_from_config(self):
import torch

for model_class in self.all_generative_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
@@ -3291,7 +3619,7 @@ class ModelTesterMixin:
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
# TODO: to change it in the future with other relevant auto classes
fa2_model = AutoModelForCausalLM.from_config(
config, use_flash_attention_2=True, torch_dtype=torch.bfloat16
config, attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16
).to(torch_device)

dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
@@ -3313,7 +3641,7 @@ class ModelTesterMixin:

model_from_pretrained = AutoModelForCausalLM.from_pretrained(tmpdirname)

self.assertFalse(getattr(model_from_pretrained.config, "_flash_attn_2_enabled", False))
self.assertTrue(model_from_pretrained.config._attn_implementation != "flash_attention_2")

fa2_correctly_converted = False
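As the `test_flash_attn_2_from_config` hunk above shows, `attn_implementation` is also accepted when a model is built from a config rather than from a checkpoint. A small sketch of that path, using a hypothetical tiny config purely for illustration ("sdpa" is chosen here so the snippet runs without flash-attn; it still requires torch>=2.1):

```python
import torch
from transformers import AutoModelForCausalLM, LlamaConfig

# Hypothetical tiny config, not one of the test checkpoints.
config = LlamaConfig(hidden_size=64, num_attention_heads=4, num_hidden_layers=2, intermediate_size=128)

model = AutoModelForCausalLM.from_config(config, attn_implementation="sdpa", torch_dtype=torch.float32)
print(model.config._attn_implementation)  # "sdpa"
```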
@@ -60,7 +60,13 @@ from transformers.utils import (
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
)
from transformers.utils.import_utils import is_flax_available, is_tf_available, is_torchdynamo_available
from transformers.utils.import_utils import (
is_flash_attn_2_available,
is_flax_available,
is_tf_available,
is_torch_sdpa_available,
is_torchdynamo_available,
)


sys.path.append(str(Path(__file__).parent.parent / "utils"))
@@ -1689,3 +1695,158 @@ class AttentionMaskTester(unittest.TestCase):
res_compiled = compiled_model(mask, inputs_embeds)

self.assertTrue(torch.equal(res_non_compiled, res_compiled))

@require_torch
@slow
def test_unmask_unattended_left_padding(self):
attention_mask = torch.Tensor([[0, 0, 1], [1, 1, 1], [0, 1, 1]]).to(torch.int64)

expanded_mask = torch.Tensor(
[
[[[0, 0, 0], [0, 0, 0], [0, 0, 1]]],
[[[1, 0, 0], [1, 1, 0], [1, 1, 1]]],
[[[0, 0, 0], [0, 1, 0], [0, 1, 1]]],
]
).to(torch.int64)

reference_output = torch.Tensor(
[
[[[1, 1, 1], [1, 1, 1], [0, 0, 1]]],
[[[1, 0, 0], [1, 1, 0], [1, 1, 1]]],
[[[1, 1, 1], [0, 1, 0], [0, 1, 1]]],
]
).to(torch.int64)

result = AttentionMaskConverter._unmask_unattended(expanded_mask, attention_mask, unmasked_value=1)

self.assertTrue(torch.equal(result, reference_output))

attention_mask = torch.Tensor([[0, 0, 1, 1, 1], [1, 1, 1, 1, 1], [0, 1, 1, 1, 1]]).to(torch.int64)

attn_mask_converter = AttentionMaskConverter(is_causal=True)
past_key_values_length = 0
key_value_length = attention_mask.shape[-1] + past_key_values_length

expanded_mask = attn_mask_converter.to_4d(
attention_mask, attention_mask.shape[-1], key_value_length=key_value_length, dtype=torch.float32
)

result = AttentionMaskConverter._unmask_unattended(expanded_mask, attention_mask, unmasked_value=0)
min_inf = torch.finfo(torch.float32).min
reference_output = torch.Tensor(
[
[
[
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[min_inf, min_inf, 0, min_inf, min_inf],
[min_inf, min_inf, 0, 0, min_inf],
[min_inf, min_inf, 0, 0, 0],
]
],
[
[
[0, min_inf, min_inf, min_inf, min_inf],
[0, 0, min_inf, min_inf, min_inf],
[0, 0, 0, min_inf, min_inf],
[0, 0, 0, 0, min_inf],
[0, 0, 0, 0, 0],
]
],
[
[
[0, 0, 0, 0, 0],
[min_inf, 0, min_inf, min_inf, min_inf],
[min_inf, 0, 0, min_inf, min_inf],
[min_inf, 0, 0, 0, min_inf],
[min_inf, 0, 0, 0, 0],
]
],
]
)

self.assertTrue(torch.equal(reference_output, result))

@require_torch
@slow
def test_unmask_unattended_right_padding(self):
attention_mask = torch.Tensor([[1, 1, 1, 0], [1, 1, 1, 1], [1, 1, 0, 0]]).to(torch.int64)

attn_mask_converter = AttentionMaskConverter(is_causal=True)
past_key_values_length = 0
key_value_length = attention_mask.shape[-1] + past_key_values_length

expanded_mask = attn_mask_converter.to_4d(
attention_mask, attention_mask.shape[-1], key_value_length=key_value_length, dtype=torch.float32
)

result = AttentionMaskConverter._unmask_unattended(expanded_mask, attention_mask, unmasked_value=0)

self.assertTrue(torch.equal(expanded_mask, result))

@require_torch
@slow
def test_unmask_unattended_random_mask(self):
attention_mask = torch.Tensor([[1, 0, 1, 0], [1, 0, 1, 1], [1, 1, 0, 1]]).to(torch.int64)

attn_mask_converter = AttentionMaskConverter(is_causal=True)
past_key_values_length = 0
key_value_length = attention_mask.shape[-1] + past_key_values_length

expanded_mask = attn_mask_converter.to_4d(
attention_mask, attention_mask.shape[-1], key_value_length=key_value_length, dtype=torch.float32
)

result = AttentionMaskConverter._unmask_unattended(expanded_mask, attention_mask, unmasked_value=0)

self.assertTrue(torch.equal(expanded_mask, result))
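The `_unmask_unattended` tests above exercise the workaround for fully-masked rows: with left padding, a padding query attends to no key at all, which can produce NaN with SDPA's mem-efficient backend, so those rows are "unmasked" to a harmless value. A pure-torch sketch of the idea (not the transformers helper itself) that reproduces the first test case:

```python
import torch

attention_mask = torch.tensor([[0, 0, 1], [1, 1, 1], [0, 1, 1]])
expanded_mask = torch.tensor(
    [
        [[[0, 0, 0], [0, 0, 0], [0, 0, 1]]],
        [[[1, 0, 0], [1, 1, 0], [1, 1, 1]]],
        [[[0, 0, 0], [0, 1, 0], [0, 1, 1]]],
    ]
)

# Rows of the 4D mask whose keys are all masked correspond to padding queries;
# fill them with the "unmasked" value so SDPA never sees an all-masked row.
fully_masked_rows = expanded_mask.sum(dim=-1, keepdim=True) == 0
unmasked = expanded_mask.masked_fill(fully_masked_rows, 1)
print(unmasked)  # matches the reference_output of test_unmask_unattended_left_padding
```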
@require_torch
class TestAttentionImplementation(unittest.TestCase):
def test_error_no_sdpa_available(self):
with self.assertRaises(ValueError) as cm:
_ = AutoModel.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="sdpa")

self.assertTrue(
"does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention"
in str(cm.exception)
)

_ = AutoModel.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel")

def test_error_no_flash_available(self):
with self.assertRaises(ValueError) as cm:
_ = AutoModel.from_pretrained(
"hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="flash_attention_2"
)

self.assertTrue("does not support Flash Attention 2.0" in str(cm.exception))

def test_error_wrong_attn_implementation(self):
with self.assertRaises(ValueError) as cm:
_ = AutoModel.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="foo")

self.assertTrue('The only possible arguments are `attn_implementation="eager"' in str(cm.exception))

def test_not_available_flash(self):
if is_flash_attn_2_available():
self.skipTest("Please uninstall flash-attn package to run test_not_available_flash")

with self.assertRaises(ImportError) as cm:
_ = AutoModel.from_pretrained(
"hf-internal-testing/tiny-random-GPTBigCodeModel", attn_implementation="flash_attention_2"
)

self.assertTrue("the package flash_attn seems to be not installed" in str(cm.exception))

def test_not_available_sdpa(self):
if is_torch_sdpa_available():
self.skipTest("This test requires torch<=2.0")

with self.assertRaises(ImportError) as cm:
_ = AutoModel.from_pretrained(
"hf-internal-testing/tiny-random-GPTBigCodeModel", attn_implementation="sdpa"
)

self.assertTrue("PyTorch SDPA requirements in Transformers are not met" in str(cm.exception))
@@ -12,11 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import doctest
import logging
import os
import unittest
from glob import glob
from pathlib import Path
from typing import List, Union

@@ -27,6 +27,63 @@ from transformers.testing_utils import require_tf, require_torch, slow
logger = logging.getLogger()


@require_torch
class TestDocLists(unittest.TestCase):
def test_flash_support_list(self):
with open("./docs/source/en/perf_infer_gpu_one.md", "r") as f:
doctext = f.read()

doctext = doctext.split("FlashAttention-2 is currently supported for the following architectures:")[1]
doctext = doctext.split("You can request to add FlashAttention-2 support")[0]

patterns = glob("./src/transformers/models/**/modeling_*.py")
patterns_tf = glob("./src/transformers/models/**/modeling_tf_*.py")
patterns_flax = glob("./src/transformers/models/**/modeling_flax_*.py")
patterns = list(set(patterns) - set(patterns_tf) - set(patterns_flax))
archs_supporting_fa2 = []
for filename in patterns:
with open(filename, "r") as f:
text = f.read()

if "_supports_flash_attn_2 = True" in text:
model_name = os.path.basename(filename).replace(".py", "").replace("modeling_", "")
archs_supporting_fa2.append(model_name)

for arch in archs_supporting_fa2:
if arch not in doctext:
raise ValueError(
f"{arch} should be listed in the flash attention documentation but is not. Please update the documentation."
)

def test_sdpa_support_list(self):
with open("./docs/source/en/perf_infer_gpu_one.md", "r") as f:
doctext = f.read()

doctext = doctext.split(
"For now, Transformers supports inference and training through SDPA for the following architectures:"
)[1]
doctext = doctext.split("Note that FlashAttention can only be used for models using the")[0]

patterns = glob("./src/transformers/models/**/modeling_*.py")
patterns_tf = glob("./src/transformers/models/**/modeling_tf_*.py")
patterns_flax = glob("./src/transformers/models/**/modeling_flax_*.py")
patterns = list(set(patterns) - set(patterns_tf) - set(patterns_flax))
archs_supporting_sdpa = []
for filename in patterns:
with open(filename, "r") as f:
text = f.read()

if "_supports_sdpa = True" in text:
model_name = os.path.basename(filename).replace(".py", "").replace("modeling_", "")
archs_supporting_sdpa.append(model_name)

for arch in archs_supporting_sdpa:
if arch not in doctext:
raise ValueError(
f"{arch} should be listed in the SDPA documentation but is not. Please update the documentation."
)


@unittest.skip("Temporarily disable the doc tests.")
@require_torch
@require_tf