Fix AttentionInterface following feedback (#37010)
* up
* typo
* update doc
* Update attention_interface.md

This commit is contained in: parent a86dad56bc, commit 2bea6bf24e
@@ -23,13 +23,13 @@ supported models.
 Most recent models can now switch from one attention function used in the Attention layer to the other, thanks to a simple mapping.
 By default, we provide the implementation for [`sdpa`](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html),
 [`flash_attention_2`](https://github.com/Dao-AILab/flash-attention) and [`flex_attention`](https://pytorch.org/docs/stable/nn.attention.flex_attention.html#module-torch.nn.attention.flex_attention)
-as well as `eager`, which is simple matrix multiplication without any optimization on top.
+as well as `eager`, which is a simple matrix multiplication without any optimization on top.
 This is the setting you can usually choose when instantiating a model:
 
 ```python
 from transformers import AutoModelForCausalLM
 
-model_id = "meta-llama/Llama-3.2-1B
+model_id = "meta-llama/Llama-3.2-1B"
 
 # Here, using flash attention as an example
 model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="flash_attention_2")
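As a practical aside on the snippet above: requesting `flash_attention_2` fails at load time when the optional `flash-attn` package is not installed. A minimal sketch of a defensive choice, where the `importlib` availability check and the fallback to `sdpa` are illustrative assumptions rather than anything the documentation prescribes:

```python
import importlib.util

from transformers import AutoModelForCausalLM

model_id = "meta-llama/Llama-3.2-1B"

# Illustrative assumption: prefer flash attention when the `flash_attn` package is
# importable, otherwise fall back to PyTorch's built-in SDPA kernel.
attn_implementation = "flash_attention_2" if importlib.util.find_spec("flash_attn") else "sdpa"

model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation=attn_implementation)
```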
@@ -43,7 +43,7 @@ from transformers import AutoModelForCausalLM, AttentionInterface
 from transformers.integrations.sdpa_attention import sdpa_attention_forward
 import torch
 
-model_id = "meta-llama/Llama-3.2-1B
+model_id = "meta-llama/Llama-3.2-1B"
 
 def my_new_sdpa(*args, **kwargs):
     print("I just entered the attention computation")
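The hunk above only shows the start of the custom function; a sketch of the full registration flow as it would read end to end, where delegating the actual computation to `sdpa_attention_forward` is an assumption made for illustration:

```python
import torch

from transformers import AttentionInterface, AutoModelForCausalLM
from transformers.integrations.sdpa_attention import sdpa_attention_forward

model_id = "meta-llama/Llama-3.2-1B"


def my_new_sdpa(*args, **kwargs):
    print("I just entered the attention computation")
    # Illustrative choice: hand the actual math off to the stock SDPA implementation.
    return sdpa_attention_forward(*args, **kwargs)


# Make the function selectable by name, exactly like the built-in ones.
AttentionInterface.register("my_new_sdpa", my_new_sdpa)

model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="my_new_sdpa")
model(torch.ones(1, 5, dtype=int))
```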
@@ -56,7 +56,7 @@ model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="my_n
 model(torch.ones(1, 5, dtype=int))
 ```
 
-You will see it prints "I just entered the attention computation" as many times as there are layers in the model (with this example, 16 times.
+You will see it prints "I just entered the attention computation" as many times as there are layers in the model (with this example, 16 times).
 
 ## Dynamically switching attention function
 
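The next hunk only shows the tail of this section's code block; as a sketch, switching the already-loaded model back to `sdpa` can look like the following, where reading and writing the private `_attn_implementation` config attribute is an assumption about the elided lines:

```python
# Continuing the example above: no reload needed, just point the config at `sdpa`.
model.config._attn_implementation = "sdpa"

model(torch.ones(1, 5, dtype=int))  # no more print statements from my_new_sdpa
```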
@@ -70,12 +70,12 @@ model(torch.ones(1, 5, dtype=int))
 ```
 
 and it will stop printing the statements, as it now uses the `sdpa` attention.
-This allows to quickly change attention function, without needing to reload the model!
+This allows to quickly change an attention function, without needing to reload the model!
 
-## What about new args needed in my custom function?
+## What about new args needed in my custom attention function?
 
 But indeed, what if the new function requires a new arg to be properly used? It's no issue! Models supporting the
-`AttentionInterface` propagates kwargs all the way to the Attention layers, and to the attention function used. That way,
+`AttentionInterface` propagate kwargs all the way to the Attention layers, and to the used attention function. That way,
 you can simply pass the arg (as a kwargs, i.e. you need to qualify the name of the arg) in the model's forward, and it will be correctly used in the attention. However, custom attention functions have some limitations. In particular, it must follow the signature and return format of other attention functions, i.e.
 
 ```python
@@ -103,4 +103,26 @@ model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="cust
 model(torch.ones(1, 5, dtype=int), a_new_kwargs=..., another_new_kwargs=...)
 ```
 
-If in doubt about what args/kwargs a given model sends to the attention function, simply check that model's modeling code on [GitHub](https://github.com/huggingface/transformers/tree/main/src/transformers/models)!
+If in doubt about what args/kwargs a given model sends to the attention function, simply check that model's modeling code on [GitHub](https://github.com/huggingface/transformers/tree/main/src/transformers/models)!
+
+## Accessing current available implementations
+
+Most of the time, you will simply need to `register` a new function. If, however, you need to access an existing one,
+and/or perform a few checks, the prefered way is to use the global `ALL_ATTENTION_FUNCTIONS`. It behaves the same way you
+would expect from a usual Python dictionary:
+
+```python
+>>> from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
+
+>>> list(ALL_ATTENTION_FUNCTIONS.keys())
+>>> ['flash_attention_2', 'flex_attention', 'sdpa']
+
+>>> ALL_ATTENTION_FUNCTIONS["sdpa"]
+>>> <function transformers.integrations.sdpa_attention.sdpa_attention_forward>
+
+>>> ALL_ATTENTION_FUNCTIONS.get("sdpa", None)
+>>> <function transformers.integrations.sdpa_attention.sdpa_attention_forward>
+
+# You can also globally `register` a new function directly on it
+>>> ALL_ATTENTION_FUNCTIONS.register("new_func", new_func)
+```
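To make the signature requirement mentioned above concrete, here is a sketch of the general shape such a custom function takes; the exact parameter list and the output layout are assumptions, so check a stock implementation such as `sdpa_attention_forward` for the authoritative version:

```python
from typing import Optional, Tuple

import torch


def custom_attention(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    a_new_kwargs=None,  # illustrative extra kwarg, passed through model(..., a_new_kwargs=...)
    **kwargs,           # keep **kwargs: models forward additional args you may not use
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    # Sketch of a body: plain SDPA, then return the output first and the
    # (optional) attention weights second, like the built-in functions do.
    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask
    )
    attn_output = attn_output.transpose(1, 2).contiguous()  # assumed layout convention
    return attn_output, None
```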
@@ -5917,7 +5917,7 @@ class AttentionInterface(MutableMapping):
     """
     Dict-like object keeping track of allowed attention functions. You can easily add a new attention function
     with a call to `register()`. If a model needs to locally overwrite an existing attention function, say `sdpa`,
-    it needs to declare a new instance of this class inside the `modeling.py`, and declare it on that instance.
+    it needs to declare a new instance of this class inside the `modeling_<model>.py`, and declare it on that instance.
     """
 
     # Class instance object, so that a call to `register` can be reflected into all other files correctly, even if
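The docstring above describes local overrides; a hypothetical sketch of what that looks like inside a model's `modeling_<model>.py`, assuming that item assignment on an instance lands in its `_local_mapping` (the "overwritten ones on top" comment in `__iter__` points that way) while `register()` stays global:

```python
from transformers import AttentionInterface
from transformers.integrations.sdpa_attention import sdpa_attention_forward


def my_model_sdpa_forward(*args, **kwargs):
    # Hypothetical model-specific tweak before delegating to the stock SDPA kernel.
    return sdpa_attention_forward(*args, **kwargs)


# A fresh instance scoped to this model; the assignment below is assumed to fill the
# instance's local mapping, leaving the global mapping used by other models untouched.
ALL_MY_MODEL_ATTENTION_FUNCTIONS = AttentionInterface()
ALL_MY_MODEL_ATTENTION_FUNCTIONS["sdpa"] = my_model_sdpa_forward
```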
@@ -5946,7 +5946,7 @@ class AttentionInterface(MutableMapping):
 
     def __iter__(self):
         # Ensure we use all keys, with the overwritten ones on top
-        return iter(self._global_mapping.update(self._local_mapping))
+        return iter({**self._global_mapping, **self._local_mapping})
 
     def __len__(self):
         return len(self._global_mapping.keys() | self._local_mapping.keys())
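A quick illustration of why the old `__iter__` line needed fixing: `dict.update()` mutates in place and returns `None`, so the previous expression blew up as soon as the instance was iterated, whereas the dict-merge form builds a fresh mapping with local overrides winning:

```python
# Old behaviour: update() returns None, so iter() raises, and the global
# mapping gets silently polluted with the local entries as a side effect.
global_mapping = {"sdpa": "global_sdpa", "eager": "global_eager"}
local_mapping = {"sdpa": "local_sdpa"}

try:
    iter(global_mapping.update(local_mapping))
except TypeError as exc:
    print(exc)  # 'NoneType' object is not iterable

# Fixed behaviour: a merged copy, local overrides on top, sources untouched.
global_mapping = {"sdpa": "global_sdpa", "eager": "global_eager"}
merged = {**global_mapping, **local_mapping}
print(list(merged))    # ['sdpa', 'eager']
print(merged["sdpa"])  # local_sdpa
```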
@@ -5956,7 +5956,7 @@ class AttentionInterface(MutableMapping):
         cls._global_mapping.update({key: value})
 
     def valid_keys(self) -> List[str]:
-        return list(self._global_mapping.keys() | self._local_mapping.keys())
+        return list(self.keys())
 
 
 # Global AttentionInterface shared by all models which do not need to overwrite any of the existing ones
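The simplified `valid_keys()` works because `MutableMapping` derives `keys()` from `__iter__`, which already merges both mappings. A toy stand-in (names and values made up) showing that the derived view covers the global entries plus local overrides without the manual set union:

```python
from collections.abc import MutableMapping


class TwoLevelMapping(MutableMapping):
    """Toy stand-in for AttentionInterface: a shared global dict plus per-instance overrides."""

    _global_mapping = {"sdpa": "sdpa_fn", "eager": "eager_fn"}

    def __init__(self):
        self._local_mapping = {}

    def __getitem__(self, key):
        if key in self._local_mapping:
            return self._local_mapping[key]
        return self._global_mapping[key]

    def __setitem__(self, key, value):
        self._local_mapping[key] = value

    def __delitem__(self, key):
        del self._local_mapping[key]

    def __iter__(self):
        return iter({**self._global_mapping, **self._local_mapping})

    def __len__(self):
        return len(self._global_mapping.keys() | self._local_mapping.keys())


mapping = TwoLevelMapping()
mapping["my_func"] = "my_fn"
# `keys()` comes for free from MutableMapping and is driven by __iter__, so
# list(mapping.keys()) is exactly the simpler valid_keys().
print(list(mapping.keys()))  # ['sdpa', 'eager', 'my_func']
```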