transformers/tests/generation
EduardDurech a2eb75c891
Support for Flash Attention 3 (#38972)
* Support `flash_attn_3`
Implements the forward pass and tests for Flash Attention 3: https://github.com/Dao-AILab/flash-attention/commits/main/hopper

- Includes checks for dropout>0 and ALiBi in `modeling_utils.PreTrainedModel._check_and_enable_flash_attn_3` (dropout will likely be supported soon, so this check will need to be updated, along with `modeling_flash_attention_utils._flash_attention_forward` at the `if _IS_FLASH_ATTN_3_AVAILABLE: ...` branch)

An example Llama implementation is included in `modeling_llama.py`, but other models still need to be updated.

Based on https://github.com/huggingface/transformers/pull/36190, which has model implementations and examples that could be merged.
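
The dropout and ALiBi checks described above could look roughly like the sketch below. This is not the code from `_check_and_enable_flash_attn_3`: the `AttnSettings` class, its field names, the function name `check_flash_attn_3_compatible`, and the error messages are all hypothetical and only illustrate the idea of rejecting unsupported settings up front.

```python
from dataclasses import dataclass


@dataclass
class AttnSettings:
    """Hypothetical container for the two settings the commit message mentions."""

    attention_dropout: float = 0.0
    uses_alibi: bool = False


def check_flash_attn_3_compatible(settings: AttnSettings) -> None:
    """Raise ValueError if a setting this sketch treats as unsupported is enabled."""
    # Assumption: dropout > 0 is rejected until the kernel supports it.
    if settings.attention_dropout > 0.0:
        raise ValueError("dropout > 0 is not supported in this sketch")
    # Assumption: ALiBi position biases are rejected as well.
    if settings.uses_alibi:
        raise ValueError("ALiBi is not supported in this sketch")


# Usage: default settings pass silently, unsupported ones raise.
check_flash_attn_3_compatible(AttnSettings())
```

Keeping the check in one place means that, if dropout support lands later, only the first condition has to be removed.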

* Add tests for Flash Attention 2 and 3 parity

* ci fix

* FA2 compatibility
- `_prepare_flash_attention_from_position_ids` -> `prepare_fa2_from_position_ids`
- Remove bettertransformer check in Flash Attention 3
- Merge tests
- Add licensing

* ci fix

* Test naming consistency

* ci fix

* Deprecation warning for `prepare_fa2_from_position_ids`

* ci fix
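
A deprecation warning for a renamed helper, as mentioned in the list above, is commonly written as a thin alias that warns and then forwards to the current function. The sketch below shows that general pattern only. The names `new_helper` and `old_helper`, the choice of `FutureWarning`, and the message text are assumptions for illustration and are not taken from the commit.

```python
import warnings


def new_helper(position_ids):
    """Hypothetical stand-in for the current helper; returns its input unchanged."""
    return position_ids


def old_helper(*args, **kwargs):
    """Deprecated alias: emit a warning, then forward to new_helper."""
    warnings.warn(
        "`old_helper` is deprecated; use `new_helper` instead.",
        FutureWarning,
        stacklevel=2,  # point the warning at the caller, not at this wrapper
    )
    return new_helper(*args, **kwargs)


# Usage: the old name still works but warns.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = old_helper([0, 1, 2])
```

Forwarding `*args` and `**kwargs` unchanged keeps the alias in sync with the current signature without duplicating it.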
2025-06-25 14:39:27 +02:00
..
__init__.py [Test refactor 1/5] Per-folder tests reorganization (#15725) 2022-02-23 15:46:28 -05:00
test_beam_constraints.py Use Python 3.9 syntax in tests (#37343) 2025-04-08 14:12:08 +02:00
test_beam_search.py Use Python 3.9 syntax in tests (#37343) 2025-04-08 14:12:08 +02:00
test_candidate_generator.py prune LM Head for USD (#36695) 2025-04-08 16:44:10 +01:00
test_configuration_utils.py 🚨🚨🚨 [pipelines] update defaults in pipelines that can generate (#38129) 2025-05-19 18:02:06 +01:00
test_flash_attention_parity.py Support for Flash Attention 3 (#38972) 2025-06-25 14:39:27 +02:00
test_fsdp.py enable generation fsdp/utils cases on XPU (#38009) 2025-05-09 20:52:41 +00:00
test_logits_process.py Allow Exclusion of Input IDs from RepetitionPenaltyLogitsProcessor (#37625) 2025-04-21 15:46:05 +01:00
test_paged_attention.py enable more test cases on xpu (#38572) 2025-06-06 09:29:51 +02:00
test_stopping_criteria.py Use Python 3.9 syntax in tests (#37343) 2025-04-08 14:12:08 +02:00
test_streamers.py Use Python 3.9 syntax in tests (#37343) 2025-04-08 14:12:08 +02:00
test_utils.py Support for Flash Attention 3 (#38972) 2025-06-25 14:39:27 +02:00