Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-03 12:50:06 +06:00)
This commit is contained in: parent cb35cf7a40, commit 7d0e91deb2
@@ -9,7 +9,7 @@ from transformers.generation import GenerationConfig
 torch.set_float32_matmul_precision("high")
 
-model_id = "meta-llama/Llama-3.2-3b-Instruct"
+model_id = "meta-llama/Llama-3.2-1b-Instruct"
 model = AutoModelForCausalLM.from_pretrained(
     model_id, attn_implementation="sdpa_paged", torch_dtype=torch.bfloat16, device_map="auto"
 ).eval()
@@ -20,14 +20,15 @@ generation_config = GenerationConfig(
     eos_token_id=tokenizer.eos_token_id,
     pad_token_id=tokenizer.pad_token_id,
     use_cache=False,
-    num_blocks=2048,
-    block_size=128,
+    num_blocks=128,
+    block_size=32,
     do_sample=True,
-    max_batch_tokens=1024, # Maximum number of tokens to process in a single batch
+    max_batch_tokens=128, # Maximum number of tokens to process in a single batch
     scheduler="prefill_first",
 )
 
 train_dataset = datasets.load_dataset("openai/gsm8k", "socratic", split="test")
 train_dataset = train_dataset.select(range(3))
 
 # --- Example 1: Simple Version using generate_batch ---
 print("--- Running CB Generation Example ---")
@@ -67,7 +68,6 @@ print("--- Finished CB Generation Example ---\n\n")
 
 print(f"CB generation took: {end_time_simple - start_time_simple:.2f} seconds")
 
-
 # train_dataset = train_dataset.select(range(5)) # Use only 5 examples for the simple version
 
 # tokenized_test_prompts = tokenizer(_TEST_PROMPTS, padding=True, padding_side="left", truncation=True, max_length=512)
@@ -107,3 +107,4 @@ print(f"CB generation took: {end_time_simple - start_time_simple:.2f} seconds")
 # print(f" Output: {output_text}")
 # print("-" * 20)
 # print("--- Finished Simple Batch Generation Example ---\n\n")
+
@@ -2,6 +2,9 @@ from typing import Optional
 
 import torch
 
+from kernels import get_kernel
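+
+# get_kernel loads the pre-built paged-attention kernel (kernels-community/paged-attention)
+# from the Hugging Face Hub; this assumes the kernel can be downloaded or is already
+# cached locally at import time.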
+paged_attention_kernel = get_kernel("kernels-community/paged-attention")
 
 
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     """
@@ -15,7 +18,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
 
 
-def sdpa_attention_paged_forward(
+def sdpa_attention_paged_forward__(
     module: torch.nn.Module,
     query: torch.Tensor,
     key: torch.Tensor,
@@ -47,5 +50,128 @@ def sdpa_attention_paged_forward(
         is_causal=False,
     )
     attn_output = attn_output.transpose(1, 2).contiguous()
     return attn_output, None
 
+
+def sdpa_attention_paged_forward(
+    module: torch.nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    dropout: float = 0.0,
+    scaling: Optional[float] = None,
+    is_causal: Optional[bool] = None,
+    **kwargs,
+) -> tuple[torch.Tensor, None]:
+    cache = kwargs.pop("cache", None)
+    if cache is not None:
+        key, value = cache.update(key, value, module.layer_idx, **kwargs)
+    if hasattr(module, "num_key_value_groups"):
+        key = repeat_kv(key, module.num_key_value_groups)
+        value = repeat_kv(value, module.num_key_value_groups)
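+
+    # When a cache object is provided, key/value now contain the full cached sequence
+    # for this layer, and repeat_kv expands grouped KV heads so that key/value have
+    # as many heads as the query.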
+
+    # Get parameters
+    batch_size, seq_len, num_heads, head_size = query.shape
+    num_kv_heads = key.shape[2]
+    block_size = kwargs.get("block_size", 32)
+    max_seq_len = kwargs.get("max_seqlen_k", seq_len)
+    x = 16  # Key cache formatting parameter
+
+    # For paged attention, we need to handle each sequence separately
+    # Reshape query to [batch_size, num_heads, head_size] - assuming seq_len=1 for generation
+    if seq_len == 1:
+        # Generation case - single token per batch
+        query_reshaped = query.squeeze(1)  # [batch_size, num_heads, head_size]
+    else:
+        # Prefill case - need to handle multiple tokens
+        query_reshaped = query.reshape(batch_size * seq_len, num_heads, head_size)
+        batch_size = batch_size * seq_len
+
+    # Calculate number of blocks needed
+    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
+    num_blocks = batch_size * max_num_blocks_per_seq
+
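+    # Scratch paged KV cache in the layout the (vLLM-style) kernel expects: the key
+    # cache packs the head dimension as [num_blocks, num_kv_heads, head_size // x,
+    # block_size, x] with x the packing width, the value cache is [num_blocks,
+    # num_kv_heads, head_size, block_size]. Both are allocated fresh on every call,
+    # so nothing persists across decoding steps.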
+    key_cache = torch.zeros(num_blocks, num_kv_heads, head_size // x, block_size, x, device=query.device, dtype=key.dtype)
+    value_cache = torch.zeros(num_blocks, num_kv_heads, head_size, block_size, device=query.device, dtype=value.dtype)
+
+    key_input = key.reshape(-1, num_kv_heads, head_size).contiguous()
+    value_input = value.reshape(-1, num_kv_heads, head_size).contiguous()
+
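+    # Identity slot mapping: flattened token i is written to cache slot i, so the
+    # physical blocks are filled contiguously in batch order.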
+    slot_mapping = torch.arange(key_input.shape[0], device=query.device)
+
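+    # reshape_and_cache scatters the new key/value vectors into the paged caches at
+    # the positions given by slot_mapping; k_scale / v_scale only matter for a
+    # quantized kv_cache_dtype and should be no-ops for the default "auto".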
+    paged_attention_kernel.reshape_and_cache(
+        key_input,
+        value_input,
+        key_cache,
+        value_cache,
+        slot_mapping,
+        kwargs.get("kv_cache_dtype", "auto"),
+        kwargs.get("k_scale", torch.tensor(1.0, device=query.device)),
+        kwargs.get("v_scale", torch.tensor(1.0, device=query.device)),
+    )
+
+    # Create proper sequence lengths and block tables
+    seq_lens = kwargs.get("cumulative_seqlens_k", None)
+    if seq_lens is None:
+        # Default: assume each sequence has seq_len tokens
+        seq_lens = torch.full((batch_size,), seq_len, device=query.device, dtype=torch.int32)
+
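+    # block_tables maps each sequence's logical block index to a physical block in
+    # key_cache / value_cache; here it is expected to come from the caller via
+    # kwargs, and the fallback construction below is kept only as disabled reference.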
+    block_tables = kwargs.get("block_tables", None)
+    # if block_tables is None:
+    #     # Create default block tables
+    #     block_tables_lst = []
+    #     for i in range(batch_size):
+    #         seq_length = seq_lens[i].item() if seq_lens is not None else seq_len
+    #         num_blocks_needed = (seq_length + block_size - 1) // block_size
+    #         block_table = []
+
+    #         for j in range(max_num_blocks_per_seq):
+    #             if j < num_blocks_needed:
+    #                 block_table.append(i * max_num_blocks_per_seq + j)
+    #             else:
+    #                 block_table.append(0)  # Padding
+
+    #         block_tables_lst.append(block_table)
+
+    #     block_tables = torch.tensor(block_tables_lst, dtype=torch.int32, device=query.device)
+
+    # Prepare query and output tensors
+    query_reshaped = query_reshaped.contiguous()
+    attn_output = torch.empty_like(query_reshaped, device=query.device)
+
+    # Ensure proper scaling
+    scale = scaling
+    if scale is None:
+        scale = torch.tensor(1.0 / (head_size ** 0.5), device=query.device)
+    elif not isinstance(scale, torch.Tensor):
+        scale = torch.tensor(scale, device=query.device)
+
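+    # Synchronize the MPS stream before launching the kernel (this path assumes the
+    # Apple-silicon / MPS backend), then run the decode-style paged attention: each
+    # query token attends over the cached slots listed in its block table, up to
+    # seq_lens positions per sequence.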
+    torch.mps.synchronize()
+    paged_attention_kernel.paged_attention_v1(
+        attn_output,
+        query_reshaped,
+        key_cache,  # Now using proper cache format
+        value_cache,  # Now using proper cache format
+        num_kv_heads=num_kv_heads,
+        block_tables=block_tables,
+        seq_lens=seq_lens,
+        block_size=block_size,
+        max_seq_len=max_seq_len,
+        kv_cache_dtype=kwargs.get("kv_cache_dtype", "auto"),
+        scale=scale,
+        k_scale=kwargs.get("k_scale", torch.tensor(1.0, device=query.device)),
+        v_scale=kwargs.get("v_scale", torch.tensor(1.0, device=query.device)),
+        alibi_slopes=kwargs.get("alibi_slopes", None),
+    )
+
+    # Reshape output back to original format
+    if seq_len == 1:
+        attn_output = attn_output.unsqueeze(1)  # Add seq_len dimension back
+    else:
+        attn_output = attn_output.reshape(query.shape[0], seq_len, num_heads, head_size)
+
+    attn_output = attn_output.contiguous()
+    attn_output = attn_output.transpose(1, 2).contiguous()
+
+    return attn_output, None