Fix AttentionInterface following feedback (#37010)

* up

* typo

* update doc

* Update attention_interface.md
This commit is contained in:
Cyril Vallez 2025-03-28 18:00:35 +01:00 committed by GitHub
parent a86dad56bc
commit 2bea6bf24e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 33 additions and 11 deletions

View File

@ -23,13 +23,13 @@ supported models.
Most recent models can now switch from one attention function used in the Attention layer to the other, thanks to a simple mapping.
By default, we provide the implementation for [`sdpa`](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html),
[`flash_attention_2`](https://github.com/Dao-AILab/flash-attention) and [`flex_attention`](https://pytorch.org/docs/stable/nn.attention.flex_attention.html#module-torch.nn.attention.flex_attention)
as well as `eager`, which is simple matrix multiplication without any optimization on top.
as well as `eager`, which is a simple matrix multiplication without any optimization on top.
This is the setting you can usually choose when instantiating a model:
```python
from transformers import AutoModelForCausalLM
model_id = "meta-llama/Llama-3.2-1B
model_id = "meta-llama/Llama-3.2-1B"
# Here, using flash attention as an example
model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="flash_attention_2")
@ -43,7 +43,7 @@ from transformers import AutoModelForCausalLM, AttentionInterface
from transformers.integrations.sdpa_attention import sdpa_attention_forward
import torch
model_id = "meta-llama/Llama-3.2-1B
model_id = "meta-llama/Llama-3.2-1B"
def my_new_sdpa(*args, **kwargs):
print("I just entered the attention computation")
@ -56,7 +56,7 @@ model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="my_n
model(torch.ones(1, 5, dtype=int))
```
You will see it prints "I just entered the attention computation" as many times as there are layers in the model (with this example, 16 times.
You will see it prints "I just entered the attention computation" as many times as there are layers in the model (with this example, 16 times).
## Dynamically switching attention function
@ -70,12 +70,12 @@ model(torch.ones(1, 5, dtype=int))
```
and it will stop printing the statements, as it now uses the `sdpa` attention.
This allows to quickly change attention function, without needing to reload the model!
This allows to quickly change an attention function, without needing to reload the model!
## What about new args needed in my custom function?
## What about new args needed in my custom attention function?
But indeed, what if the new function requires a new arg to be properly used? It's no issue! Models supporting the
`AttentionInterface` propagates kwargs all the way to the Attention layers, and to the attention function used. That way,
`AttentionInterface` propagate kwargs all the way to the Attention layers, and to the used attention function. That way,
you can simply pass the arg (as a kwargs, i.e. you need to qualify the name of the arg) in the model's forward, and it will be correctly used in the attention. However, custom attention functions have some limitations. In particular, it must follow the signature and return format of other attention functions, i.e.
```python
@ -103,4 +103,26 @@ model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="cust
model(torch.ones(1, 5, dtype=int), a_new_kwargs=..., another_new_kwargs=...)
```
If in doubt about what args/kwargs a given model sends to the attention function, simply check that model's modeling code on [GitHub](https://github.com/huggingface/transformers/tree/main/src/transformers/models)!
If in doubt about what args/kwargs a given model sends to the attention function, simply check that model's modeling code on [GitHub](https://github.com/huggingface/transformers/tree/main/src/transformers/models)!
## Accessing current available implementations
Most of the time, you will simply need to `register` a new function. If, however, you need to access an existing one,
and/or perform a few checks, the prefered way is to use the global `ALL_ATTENTION_FUNCTIONS`. It behaves the same way you
would expect from a usual Python dictionary:
```python
>>> from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
>>> list(ALL_ATTENTION_FUNCTIONS.keys())
>>> ['flash_attention_2', 'flex_attention', 'sdpa']
>>> ALL_ATTENTION_FUNCTIONS["sdpa"]
>>> <function transformers.integrations.sdpa_attention.sdpa_attention_forward>
>>> ALL_ATTENTION_FUNCTIONS.get("sdpa", None)
>>> <function transformers.integrations.sdpa_attention.sdpa_attention_forward>
# You can also globally `register` a new function directly on it
>>> ALL_ATTENTION_FUNCTIONS.register("new_func", new_func)
```

View File

@ -5917,7 +5917,7 @@ class AttentionInterface(MutableMapping):
"""
Dict-like object keeping track of allowed attention functions. You can easily add a new attention function
with a call to `register()`. If a model needs to locally overwrite an existing attention function, say `sdpa`,
it needs to declare a new instance of this class inside the `modeling.py`, and declare it on that instance.
it needs to declare a new instance of this class inside the `modeling_<model>.py`, and declare it on that instance.
"""
# Class instance object, so that a call to `register` can be reflected into all other files correctly, even if
@ -5946,7 +5946,7 @@ class AttentionInterface(MutableMapping):
def __iter__(self):
# Ensure we use all keys, with the overwritten ones on top
return iter(self._global_mapping.update(self._local_mapping))
return iter({**self._global_mapping, **self._local_mapping})
def __len__(self):
return len(self._global_mapping.keys() | self._local_mapping.keys())
@ -5956,7 +5956,7 @@ class AttentionInterface(MutableMapping):
cls._global_mapping.update({key: value})
def valid_keys(self) -> List[str]:
return list(self._global_mapping.keys() | self._local_mapping.keys())
return list(self.keys())
# Global AttentionInterface shared by all models which do not need to overwrite any of the existing ones