transformers/tests/generation/test_paged_attention.py
Yuanyuan Chen 11b670a282
Fix run_slow (#38314)
Signed-off-by: cyy <cyyever@outlook.com>
2025-05-23 10:18:30 +02:00

87 lines
4.1 KiB
Python

import time
import unittest
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from transformers.testing_utils import require_flash_attn, require_torch_gpu, slow
# Prompts tokenized and submitted together to `generate_batch` by the test
# below. The wording (including its typos) is part of the fixture: the recorded
# continuations in `_EXPECTED_OUTPUTS` were presumably produced from these exact
# strings, so do not "correct" them without re-recording the expectations.
_TEST_PROMPTS = [
    "A man is a walking his dog down the street, and a the turn he sees",
    "Describe a fruit that is of orange color and round. It is a sweet fruit and a great source of Vitamine C. The fruit I'm thinking of is an",
    "A plane is flying high in the sky, out of the window are clouds and mountains. Where could the plane be located?",
    "Please fill in the form to",
    "For safety reasons, the train is stopped in the middle of the",
]
# Expected beginning of the decoded generation, one entry per prompt. The test
# compares entry `i` (after `.strip()`) with the i-th returned result using
# `str.startswith`, so each entry only needs to be a prefix of the real output.
_EXPECTED_OUTPUTS = [
    "a woman standing on the sidewalk, looking at him. He is immediately drawn to her and feels a strong attraction. He walks up to her and strikes up a conversation, and they quickly discover that they have a lot in common. They exchange numbers and",
    "orange.\n\n## Step 1: Identify the key characteristics of the fruit\nThe fruit is described as being orange in color and round in shape.\n\n## Step 2: Determine the taste and nutritional value of the fruit\nThe fruit is described as sweet",
    "This riddle is a classic example of a lateral thinking puzzle, which requires the test-taker to think creatively and consider multiple possibilities. The answer is not a straightforward one, and it requires some lateral thinking to arrive at the correct solution.",
    "get in touch with us. We will respond to your message as soon as possible.\n\n[Your Name]\n[Your Email]\n[Your Phone Number]\n[Your Message]\n\nWe are looking forward to hearing from you!\n\n[Insert Contact Information]\n\nNote:",
    "track. The train is stopped for 30 minutes. The train is moving at a speed of 60 km/h. How many kilometers does the train travel in 30 minutes?\n## Step 1: Convert the speed from km/h to km/min",
]
@slow
@require_torch_gpu
@require_flash_attn
class TestBatchGeneration(unittest.TestCase):
    """Compare `generate_batch` output with recorded continuations.

    A single model/tokenizer pair is loaded in `setUpClass` and shared by all
    parameterized cases. Each case selects an attention implementation and a
    `(num_blocks, block_size, max_batch_tokens)` combination.
    """

    @classmethod
    def setUpClass(cls):
        """Load the shared model (bfloat16, eval mode) and left-padding tokenizer."""
        cls.model = AutoModelForCausalLM.from_pretrained(
            "meta-llama/Llama-3.2-3b-Instruct", torch_dtype="bfloat16", device_map="auto"
        ).eval()

        cls.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3b-Instruct", padding_side="left")

        # Fall back to EOS for padding when the tokenizer defines no pad token.
        if cls.tokenizer.pad_token is None:
            cls.tokenizer.pad_token = cls.tokenizer.eos_token
            cls.model.config.pad_token_id = cls.model.config.eos_token_id

        # NOTE(review): this sets an attribute on the model object itself, not
        # on `model.config` -- confirm that this is the intended target.
        cls.model.use_cache = False

    @parameterized.expand(
        [
            ("eager_paged", 64, 128, 64),
            ("sdpa_paged", 32, 256, 128),
            ("paged_attention", 16, 512, 256),
            ("flex_paged", 64, 128, 64),
        ]
    )
    def test_generate_batch_consistency(self, attn_impl, num_blocks, block_size, max_batch_tokens):
        """Generate for every prompt in one batch and check each continuation.

        Args:
            attn_impl: Value assigned to `model.config.attn_implementation`.
            num_blocks: Forwarded to `GenerationConfig(num_blocks=...)`.
            block_size: Forwarded to `GenerationConfig(block_size=...)`.
            max_batch_tokens: Forwarded to `GenerationConfig(max_batch_tokens=...)`.
        """
        # The model is shared across cases, so every case must set this itself.
        self.model.config.attn_implementation = attn_impl

        generation_config = GenerationConfig(
            max_new_tokens=50,
            top_k=0,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
            use_cache=False,
            num_blocks=num_blocks,
            block_size=block_size,
            max_batch_tokens=max_batch_tokens,
        )

        tokenized = self.tokenizer(_TEST_PROMPTS, truncation=True, max_length=512)
        batch_inputs = list(tokenized["input_ids"])

        # perf_counter is monotonic, so the interval cannot be skewed by
        # system clock adjustments during the (slow) generation call.
        start = time.perf_counter()
        batch_outputs = self.model.generate_batch(
            inputs=batch_inputs,
            generation_config=generation_config,
        )
        elapsed = time.perf_counter() - start
        print(
            f"\n[{attn_impl}] Batch took {elapsed:.2f}s with config: blocks={num_blocks}, block_size={block_size}, max_batch_tokens={max_batch_tokens}"
        )

        # Without this check an empty or partial result would make the loop
        # below pass without verifying every prompt.
        self.assertEqual(
            len(batch_outputs),
            len(_TEST_PROMPTS),
            msg=f"[{attn_impl}] Expected {len(_TEST_PROMPTS)} results, got {len(batch_outputs)}",
        )

        # NOTE(review): results are matched to expectations by the iteration
        # order of `batch_outputs`; this assumes they come back in prompt
        # order -- confirm against `generate_batch`.
        for i, req_id in enumerate(batch_outputs):
            generated = self.tokenizer.decode(batch_outputs[req_id].static_outputs, skip_special_tokens=False).strip()
            expected = _EXPECTED_OUTPUTS[i].strip()
            self.assertTrue(
                generated.startswith(expected),
                msg=f"[{attn_impl}] Mismatch in request {i}:\nExpected start: {expected}\nGot: {generated}",
            )