# Mirror of https://github.com/huggingface/transformers.git
# Synced 2025-07-03 12:50:06 +06:00
# 87 lines | 4.1 KiB | Python
import time
import unittest

from parameterized import parameterized

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from transformers.testing_utils import require_flash_attn, require_torch_gpu, slow

# Fixed prompts fed as a single batch to `generate_batch` in the tests below.
# The grammatical glitches in some prompts are intentional parts of the fixture
# and must not be "corrected" — the expected outputs depend on the exact bytes.
_TEST_PROMPTS = [
    "A man is a walking his dog down the street, and a the turn he sees",
    "Describe a fruit that is of orange color and round. It is a sweet fruit and a great source of Vitamine C. The fruit I'm thinking of is an",
    "A plane is flying high in the sky, out of the window are clouds and mountains. Where could the plane be located?",
    "Please fill in the form to",
    "For safety reasons, the train is stopped in the middle of the",
]
# Expected prefixes of the decoded continuation for each prompt above,
# index-aligned with _TEST_PROMPTS. The test only asserts `startswith` on the
# stripped text, so these need not be the full 50-token generations.
_EXPECTED_OUTPUTS = [
    "a woman standing on the sidewalk, looking at him. He is immediately drawn to her and feels a strong attraction. He walks up to her and strikes up a conversation, and they quickly discover that they have a lot in common. They exchange numbers and",
    "orange.\n\n## Step 1: Identify the key characteristics of the fruit\nThe fruit is described as being orange in color and round in shape.\n\n## Step 2: Determine the taste and nutritional value of the fruit\nThe fruit is described as sweet",
    "This riddle is a classic example of a lateral thinking puzzle, which requires the test-taker to think creatively and consider multiple possibilities. The answer is not a straightforward one, and it requires some lateral thinking to arrive at the correct solution.",
    "get in touch with us. We will respond to your message as soon as possible.\n\n[Your Name]\n[Your Email]\n[Your Phone Number]\n[Your Message]\n\nWe are looking forward to hearing from you!\n\n[Insert Contact Information]\n\nNote:",
    "track. The train is stopped for 30 minutes. The train is moving at a speed of 60 km/h. How many kilometers does the train travel in 30 minutes?\n## Step 1: Convert the speed from km/h to km/min",
]
@slow
@require_torch_gpu
@require_flash_attn
class TestBatchGeneration(unittest.TestCase):
    """Consistency tests for `generate_batch` across paged-attention backends.

    For each attention implementation / paged-cache configuration, a fixed
    batch of prompts is generated greedily (``top_k=0``) and the decoded text
    is checked to start with the known-good continuation in
    ``_EXPECTED_OUTPUTS``.
    """

    @classmethod
    def setUpClass(cls):
        # The model is large; load it once for the whole class. Individual
        # tests only switch `config.attn_implementation`, they do not retrain
        # or otherwise mutate weights.
        cls.model = AutoModelForCausalLM.from_pretrained(
            "meta-llama/Llama-3.2-3b-Instruct", torch_dtype="bfloat16", device_map="auto"
        ).eval()

        # Left padding so that generation continues from the end of each
        # prompt in a batched setting.
        cls.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3b-Instruct", padding_side="left")

        # Llama tokenizers ship without a pad token; reuse EOS for padding.
        if cls.tokenizer.pad_token is None:
            cls.tokenizer.pad_token = cls.tokenizer.eos_token
            cls.model.config.pad_token_id = cls.model.config.eos_token_id

        # Fix: the original wrote `cls.model.use_cache = False`, which merely
        # attaches an unused attribute to the nn.Module — the flag is read
        # from the config (consistent with `cls.model.config.pad_token_id`
        # above). The per-test GenerationConfig also passes use_cache=False.
        cls.model.config.use_cache = False

    @parameterized.expand(
        [
            # (attn_implementation, num_blocks, block_size, max_batch_tokens)
            ("eager_paged", 64, 128, 64),
            ("sdpa_paged", 32, 256, 128),
            ("paged_attention", 16, 512, 256),
            ("flex_paged", 64, 128, 64),
        ]
    )
    def test_generate_batch_consistency(self, attn_impl, num_blocks, block_size, max_batch_tokens):
        """Generate the fixed prompt batch with `attn_impl` and assert each
        decoded output starts with its expected continuation."""
        # NOTE(review): switching the implementation by assigning
        # `config.attn_implementation` post-load is assumed to re-dispatch the
        # attention kernels — confirm against the installed transformers
        # version (newer releases expose `model.set_attn_implementation`).
        self.model.config.attn_implementation = attn_impl

        generation_config = GenerationConfig(
            max_new_tokens=50,
            top_k=0,  # greedy decoding, so outputs are deterministic
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
            use_cache=False,
            num_blocks=num_blocks,
            block_size=block_size,
            max_batch_tokens=max_batch_tokens,
        )

        tokenized = self.tokenizer(_TEST_PROMPTS, truncation=True, max_length=512)
        batch_inputs = list(tokenized["input_ids"])

        # Fix: use the monotonic high-resolution clock for interval timing;
        # time.time() is wall-clock and can jump (NTP, DST).
        start = time.perf_counter()
        batch_outputs = self.model.generate_batch(
            inputs=batch_inputs,
            generation_config=generation_config,
        )
        end = time.perf_counter()
        print(
            f"\n[{attn_impl}] Batch took {end - start:.2f}s with config: blocks={num_blocks}, block_size={block_size}, max_batch_tokens={max_batch_tokens}"
        )

        # Assumes iteration order of `batch_outputs` matches the order of the
        # submitted prompts, so index i lines up with _EXPECTED_OUTPUTS[i] —
        # TODO confirm against the generate_batch contract.
        for i, req_id in enumerate(batch_outputs):
            generated = self.tokenizer.decode(batch_outputs[req_id].static_outputs, skip_special_tokens=False).strip()
            expected = _EXPECTED_OUTPUTS[i].strip()
            self.assertTrue(
                generated.startswith(expected),
                msg=f"[{attn_impl}] Mismatch in request {i}:\nExpected start: {expected}\nGot: {generated}",
            )