Mirror of https://github.com/huggingface/transformers.git — synced 2025-07-03 12:50:06 +06:00

* Support `flash_attn_3`

  Implements the forward pass and tests for Flash Attention 3 (https://github.com/Dao-AILab/flash-attention/commits/main/hopper).
  - Includes checks for dropout>0 and ALiBi in `modeling_utils.PreTrainedModel._check_and_enable_flash_attn_3`. (Dropout will likely be supported soon, so this will need to be updated, as will `modeling_flash_attention_utils._flash_attention_forward` at the `if _IS_FLASH_ATTN_3_AVAILABLE: ...` branch.)
  - An example Llama implementation is included in `modeling_llama.py`, but other models would still need to be updated.
  - Based on https://github.com/huggingface/transformers/pull/36190, which has model implementations and examples that could be merged.
* Add tests for Flash Attention 2 and 3 parity
* ci fix
* FA2 compatibility
  - `_prepare_flash_attention_from_position_ids` -> `prepare_fa2_from_position_ids`
  - Remove bettertransformer check in Flash Attention 3
  - Merge tests
  - Add licensing
* ci fix
* Test naming consistency
* ci fix
* Deprecation warning for `prepare_fa2_from_position_ids`
* ci fix
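A minimal usage sketch for the new implementation name, mirroring the test file below (assumes the FA3 kernels from the Hopper branch are installed on a supported GPU; the model and prompt are the ones used in the parity test):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-3.2-1B-Instruct"  # Llama is the example implementation in this PR
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_3",  # raises if FA3 is unavailable on this setup
).to("cuda")
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("The ETH AI Center is", return_tensors="pt").to("cuda")
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```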
145 lines
5.6 KiB
Python
# Copyright 2025 Eduard Durech and SGLang team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Usage:
# RUN_SLOW=1 pytest -s tests/generation/test_flash_attention_parity.py

import unittest

import pytest
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.testing_utils import require_flash_attn, require_flash_attn_3, require_torch_gpu, slow


class FlashAttentionParityTest(unittest.TestCase):
    # From https://github.com/sgl-project/sglang/blob/main/python/sglang/test/test_utils.py
    def _lcs(self, X, Y):
        m = len(X)
        n = len(Y)
        L = [[0] * (n + 1) for _ in range(m + 1)]
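        # L[i][j] holds the length of the longest common subsequence of X[:i] and Y[:j];
        # the table is filled bottom-up with the standard LCS recurrence below.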

        for i in range(m + 1):
            for j in range(n + 1):
                if i == 0 or j == 0:
                    L[i][j] = 0
                elif X[i - 1] == Y[j - 1]:
                    L[i][j] = L[i - 1][j - 1] + 1
                else:
                    L[i][j] = max(L[i - 1][j], L[i][j - 1])

        return L[m][n]

    # From https://github.com/sgl-project/sglang/blob/main/python/sglang/test/test_utils.py
    def _calculate_rouge_l(self, output_strs_list1, output_strs_list2):
        rouge_l_scores = []
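        # ROUGE-L here is the F-measure of LCS-based precision and recall, computed
        # character-wise over each pair of decoded strings.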

        for s1, s2 in zip(output_strs_list1, output_strs_list2):
            lcs_len = self._lcs(s1, s2)
            precision = lcs_len / len(s1) if len(s1) > 0 else 0
            recall = lcs_len / len(s2) if len(s2) > 0 else 0
            if precision + recall > 0:
                fmeasure = (2 * precision * recall) / (precision + recall)
            else:
                fmeasure = 0.0
            rouge_l_scores.append(fmeasure)

        return rouge_l_scores

    def _benchmark_generation(self, model, inputs, n_warmup=3, n_runs=5):
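        # Average greedy-decoding latency in ms over n_runs, after n_warmup warm-up calls;
        # CUDA events bracket the timed loop so asynchronous GPU work is measured correctly.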
        for _ in range(n_warmup):
            model.generate(**inputs, max_new_tokens=20, do_sample=False)
        torch.cuda.synchronize()

        start_time = torch.cuda.Event(enable_timing=True)
        end_time = torch.cuda.Event(enable_timing=True)

        start_time.record()
        for _ in range(n_runs):
            model.generate(**inputs, max_new_tokens=20, do_sample=False)
        end_time.record()
        torch.cuda.synchronize()

        return start_time.elapsed_time(end_time) / n_runs

    @pytest.mark.flash_attn_3_test
    @require_torch_gpu
    @require_flash_attn
    @require_flash_attn_3
    @slow
    def test_flash_attention_2_3_parity(self):
        model_id = "meta-llama/Llama-3.2-1B-Instruct"
        prompt = "The ETH AI Center is"

        # 1. Load FA2 model and tokenizer
        model_2 = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
        ).to("cuda")
        tokenizer = AutoTokenizer.from_pretrained(model_id)

        # 2. Load FA3 model
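        # Loading with attn_implementation="flash_attention_3" can raise ValueError or
        # ImportError when the FA3 kernels are missing or unsupported; skip in that case.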
        try:
            model_3 = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.bfloat16,
                attn_implementation="flash_attention_3",
            ).to("cuda")
        except (ValueError, ImportError) as e:
            pytest.skip(f"Could not load Flash Attention 3 model, skipping test. Error: {e}")

        # 3. Generate with both models
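        # return_dict_in_generate=True with output_scores=True exposes the per-step scores
        # (logits) needed for the numerical comparison below.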
        inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

        with torch.no_grad():
            output_2 = model_2.generate(
                **inputs, max_new_tokens=20, do_sample=False, output_scores=True, return_dict_in_generate=True
            )
            output_3 = model_3.generate(
                **inputs, max_new_tokens=20, do_sample=False, output_scores=True, return_dict_in_generate=True
            )

        # 4. Correctness check
        # 4a. Logits
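        # output.scores is one tensor of vocabulary logits per generated token; FA2 and FA3
        # should agree within a small numerical tolerance (atol=rtol=1e-3) under bf16.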
        logits_2 = torch.stack(output_2.scores)
        logits_3 = torch.stack(output_3.scores)
        torch.testing.assert_close(logits_2, logits_3, atol=1e-3, rtol=1e-3)
        logprobs_2 = torch.nn.functional.log_softmax(logits_2, dim=-1)
        logprobs_3 = torch.nn.functional.log_softmax(logits_3, dim=-1)
        max_logprob_diff = torch.max(torch.abs(logprobs_2 - logprobs_3)).item()

        # 4b. Generated text
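        # Greedy decodings should be effectively identical; require character-level
        # ROUGE-L > 0.99 between the two generations.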
        text_2 = tokenizer.decode(output_2.sequences[0], skip_special_tokens=True)
        text_3 = tokenizer.decode(output_3.sequences[0], skip_special_tokens=True)
        rouge_score = self._calculate_rouge_l([text_2], [text_3])[0]
        assert rouge_score > 0.99, f"Generated texts do not match (ROUGE-L: {rouge_score})"

        # 5. Performance check
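        # The latency comparison is informational only; no assertion is made on the speed-up.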
        with torch.no_grad():
            time_2 = self._benchmark_generation(model_2, inputs)
            time_3 = self._benchmark_generation(model_3, inputs)

        print(f"\n--- Flash Attention {2, 3} Parity Test on {model_id} ---")
        print(f"Prompt: '{prompt}'")
        print(f"Generated text with Flash Attention 2: {text_2}")
        print(f"Generated text with Flash Attention 3: {text_3}")
        print(f"ROUGE-L: {rouge_score}")
        print(f"Max absolute difference in logprobs: {max_logprob_diff:.5e}")
        print(f"Flash Attention 2 latency: {time_2:.2f} ms")
        print(f"Flash Attention 3 latency: {time_3:.2f} ms")
        print(f"Speed-up: {time_2 / time_3:.2f}x")
        print("---")