Commit Graph

1 Commits

Author SHA1 Message Date
EduardDurech
a2eb75c891
Support for Flash Attention 3 (#38972)
* Support `flash_attn_3`
Implements fwd and tests for Flash Attention 3 https://github.com/Dao-AILab/flash-attention/commits/main/hopper

- Includes checks for dropout>0 and ALiBi in `modeling_utils.PreTrainedModel._check_and_enable_flash_attn_3` (Dropout will likely be supported soon, so this will need to be updated and `modeling_flash_attention_utils._flash_attention_forward` at the `if _IS_FLASH_ATTN_3_AVAILABLE: ...`

An example Llama implementation is included in `modeling_llama.py` but other models would still need to be updated

Based on https://github.com/huggingface/transformers/pull/36190 which has model implementations and examples which could be merged

* Add tests for Flash Attention 2 and 3 parity

* ci fix

* FA2 compatibiity
- `_prepare_flash_attention_from_position_ids` ->`prepare_fa2_from_position_ids`
- Remove bettertransformer check in Flash Attention 3
- Merge tests
- Add licensing

* ci fix

* Test naming consistency

* ci fix

* Deprecation warning for `prepare_fa2_from_position_ids`

* ci fix
2025-06-25 14:39:27 +02:00