diff --git a/examples/pytorch/continuous_batching.py b/examples/pytorch/continuous_batching.py
index 9aaa836f7ba..a8600c839cd 100644
--- a/examples/pytorch/continuous_batching.py
+++ b/examples/pytorch/continuous_batching.py
@@ -9,7 +9,7 @@ from transformers.generation import GenerationConfig
 
 torch.set_float32_matmul_precision("high")
 
-model_id = "meta-llama/Llama-3.2-3b-Instruct"
+model_id = "meta-llama/Llama-3.2-1b-Instruct"
 model = AutoModelForCausalLM.from_pretrained(
     model_id, attn_implementation="sdpa_paged", torch_dtype=torch.bfloat16, device_map="auto"
 ).eval()
@@ -20,14 +20,15 @@ generation_config = GenerationConfig(
     eos_token_id=tokenizer.eos_token_id,
     pad_token_id=tokenizer.pad_token_id,
     use_cache=False,
-    num_blocks=2048,
-    block_size=128,
+    num_blocks=128,
+    block_size=32,
     do_sample=True,
-    max_batch_tokens=1024,  # Maximum number of tokens to process in a single batch
+    max_batch_tokens=128,  # Maximum number of tokens to process in a single batch
     scheduler="prefill_first",
 )
 
 train_dataset = datasets.load_dataset("openai/gsm8k", "socratic", split="test")
+train_dataset = train_dataset.select(range(3))
 
 # --- Example 1: Simple Version using generate_batch ---
 print("--- Running CB Generation Example ---")
@@ -67,7 +68,6 @@ print("--- Finished CB Generation Example ---\n\n")
 
 print(f"CB generation took: {end_time_simple - start_time_simple:.2f} seconds")
 
-
 # train_dataset = train_dataset.select(range(5))  # Use only 5 examples for the simple version
 
 # tokenized_test_prompts = tokenizer(_TEST_PROMPTS, padding=True, padding_side="left", truncation=True, max_length=512)
@@ -107,3 +107,4 @@ print(f"CB generation took: {end_time_simple - start_time_simple:.2f} seconds")
 #         print(f"    Output: {output_text}")
 #         print("-" * 20)
 # print("--- Finished Simple Batch Generation Example ---\n\n")
+
diff --git a/src/transformers/integrations/sdpa_paged.py b/src/transformers/integrations/sdpa_paged.py
index 558f4a6f715..d42095d8952 100644
--- a/src/transformers/integrations/sdpa_paged.py
+++ b/src/transformers/integrations/sdpa_paged.py
@@ -2,6 +2,9 @@ from typing import Optional
 
 import torch
+from kernels import get_kernel
+
+paged_attention_kernel = get_kernel("kernels-community/paged-attention")
 
 
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     """
@@ -15,7 +18,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
 
 
-def sdpa_attention_paged_forward(
+def sdpa_attention_paged_forward__(
     module: torch.nn.Module,
     query: torch.Tensor,
     key: torch.Tensor,
@@ -47,5 +50,128 @@ def sdpa_attention_paged_forward(
         is_causal=False,
     )
     attn_output = attn_output.transpose(1, 2).contiguous()
+    return attn_output, None
+
+
+def sdpa_attention_paged_forward(
+    module: torch.nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    dropout: float = 0.0,
+    scaling: Optional[float] = None,
+    is_causal: Optional[bool] = None,
+    **kwargs,
+) -> tuple[torch.Tensor, None]:
+    cache = kwargs.pop("cache", None)
+    if cache is not None:
+        key, value = cache.update(key, value, module.layer_idx, **kwargs)
+    if hasattr(module, "num_key_value_groups"):
+        key = repeat_kv(key, module.num_key_value_groups)
+        value = repeat_kv(value, module.num_key_value_groups)
+
+    # Get parameters
+    batch_size, seq_len, num_heads, head_size = query.shape
+    num_kv_heads = key.shape[2]
+    block_size = kwargs.get("block_size", 32)
+    max_seq_len = kwargs.get("max_seqlen_k", seq_len)
+    x = 16  # Key cache formatting parameter
+
+    # For paged attention, we need to handle each sequence separately
+    # Reshape query to [batch_size, num_heads, head_size] - assuming seq_len=1 for generation
+    if seq_len == 1:
+        # Generation case - single token per batch
+        query_reshaped = query.squeeze(1)  # [batch_size, num_heads, head_size]
+    else:
+        # Prefill case - need to handle multiple tokens
+        query_reshaped = query.reshape(batch_size * seq_len, num_heads, head_size)
+        batch_size = batch_size * seq_len
+
+    # Calculate number of blocks needed
+    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
+    num_blocks = batch_size * max_num_blocks_per_seq
+
+    key_cache = torch.zeros(num_blocks, num_kv_heads, head_size // x, block_size, x, device=query.device, dtype=key.dtype)
+    value_cache = torch.zeros(num_blocks, num_kv_heads, head_size, block_size, device=query.device, dtype=value.dtype)
+
+    key_input = key.reshape(-1, num_kv_heads, head_size).contiguous()
+    value_input = value.reshape(-1, num_kv_heads, head_size).contiguous()
+
+    slot_mapping = torch.arange(key_input.shape[0], device=query.device)
+
+    paged_attention_kernel.reshape_and_cache(
+        key_input,
+        value_input,
+        key_cache,
+        value_cache,
+        slot_mapping,
+        kwargs.get("kv_cache_dtype", "auto"),
+        kwargs.get("k_scale", torch.tensor(1.0, device=query.device)),
+        kwargs.get("v_scale", torch.tensor(1.0, device=query.device)),
+    )
+
+    # Create proper sequence lengths and block tables
+    seq_lens = kwargs.get("cumulative_seqlens_k", None)
+    if seq_lens is None:
+        # Default: assume each sequence has seq_len tokens
+        seq_lens = torch.full((batch_size,), seq_len, device=query.device, dtype=torch.int32)
+
+    block_tables = kwargs.get("block_tables", None)
+    # if block_tables is None:
+    #     # Create default block tables
+    #     block_tables_lst = []
+    #     for i in range(batch_size):
+    #         seq_length = seq_lens[i].item() if seq_lens is not None else seq_len
+    #         num_blocks_needed = (seq_length + block_size - 1) // block_size
+    #         block_table = []
+
+    #         for j in range(max_num_blocks_per_seq):
+    #             if j < num_blocks_needed:
+    #                 block_table.append(i * max_num_blocks_per_seq + j)
+    #             else:
+    #                 block_table.append(0)  # Padding
+
+    #         block_tables_lst.append(block_table)
+
+    #     block_tables = torch.tensor(block_tables_lst, dtype=torch.int32, device=query.device)
+
+    # Prepare query and output tensors
+    query_reshaped = query_reshaped.contiguous()
+    attn_output = torch.empty_like(query_reshaped, device=query.device)
+
+    # Ensure proper scaling
+    scale = scaling
+    if scale is None:
+        scale = torch.tensor(1.0 / (head_size ** 0.5), device=query.device)
+    elif not isinstance(scale, torch.Tensor):
+        scale = torch.tensor(scale, device=query.device)
+
+    torch.mps.synchronize()
+    paged_attention_kernel.paged_attention_v1(
+        attn_output,
+        query_reshaped,
+        key_cache,  # Now using proper cache format
+        value_cache,  # Now using proper cache format
+        num_kv_heads=num_kv_heads,
+        block_tables=block_tables,
+        seq_lens=seq_lens,
+        block_size=block_size,
+        max_seq_len=max_seq_len,
+        kv_cache_dtype=kwargs.get("kv_cache_dtype", "auto"),
+        scale=scale,
+        k_scale=kwargs.get("k_scale", torch.tensor(1.0, device=query.device)),
+        v_scale=kwargs.get("v_scale", torch.tensor(1.0, device=query.device)),
+        alibi_slopes=kwargs.get("alibi_slopes", None),
+    )
+
+    # Reshape output back to original format
+    if seq_len == 1:
+        attn_output = attn_output.unsqueeze(1)  # Add seq_len dimension back
+    else:
+        attn_output = attn_output.reshape(query.shape[0], seq_len, num_heads, head_size)
+
+    attn_output = attn_output.contiguous()
+    attn_output = attn_output.transpose(1, 2).contiguous()
 
     return attn_output, None
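
A minimal smoke-test sketch for the new decode path (seq_len == 1), not part of the patch above. It assumes the kernels package and the kernels-community/paged-attention kernel are installed and that an MPS device is available (the new forward calls torch.mps.synchronize()); DummyAttention, the tensor sizes, and the one-block block table are illustrative stand-ins that mirror the commented-out defaults in the diff.

# Hypothetical smoke test for the seq_len == 1 path; names and shapes are illustrative only.
import torch

from transformers.integrations.sdpa_paged import sdpa_attention_paged_forward


class DummyAttention(torch.nn.Module):
    """Minimal stand-in exposing the attributes the forward reads."""
    layer_idx = 0
    num_key_value_groups = 1


device = "mps"  # the patched forward synchronizes via torch.mps, so assume Apple silicon
batch, heads, head_size, block_size = 1, 8, 64, 32

# Layout assumed by the new forward: [batch, seq_len, num_heads, head_size]
q = torch.randn(batch, 1, heads, head_size, device=device, dtype=torch.float16)
k = torch.randn(batch, 1, heads, head_size, device=device, dtype=torch.float16)
v = torch.randn(batch, 1, heads, head_size, device=device, dtype=torch.float16)

# One block per sequence, mirroring the commented-out default block-table construction
block_tables = torch.zeros(batch, 1, dtype=torch.int32, device=device)
seq_lens = torch.ones(batch, dtype=torch.int32, device=device)

out, _ = sdpa_attention_paged_forward(
    DummyAttention(),
    q,
    k,
    v,
    attention_mask=None,
    block_size=block_size,
    max_seqlen_k=1,
    block_tables=block_tables,
    cumulative_seqlens_k=seq_lens,
)
print(out.shape)  # [batch, num_heads, 1, head_size] after the final transpose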