From 1c2f36b480e02c9027d2523746d34e27b39e01a4 Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Tue, 20 May 2025 15:22:52 +0100 Subject: [PATCH] parallelism goes brrr (#37877) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * accept custom device_mesh * fix device_map * assert that num_heads % tp_size == 0 * todo. * ReplicateParallel * handle tied weights * handle dtensor in save_pretrained with safe_serialization * tp test works * doesnt work * fix shard_and_distribute_module's rank should be local_rank * tp=4 is correct * dp+tp is broken * todo allreduce with dtensors on another dim is annoying * workaround to sync dp grads when using dtensors * loading a checkpoint works * wandb and compare losses with different tp/dp * cleaning * cleaning * . * . * logs * CP2 DP2 no mask works after commenting attn_mask and is_causal from scaled_dot_product_attention * DP=2 TP=2 now works even with tied embeddings * model.parameters() and model.module.parameters() are empty.. * reformat sanity_check_tensor_sync * set atol=1e-4 for CP to pass * try populate _parameters from named_modules * refactors TP2 DP2 works CP2 DP2 works * is_causal=True and pack sequences, no attn mask, and preshuffle dataset * fix packing * CP=4 doesn't work * fix labels and position_ids for CP * DP CP works with transformers 🥳🥳🥳 * refactor * add example cp * fixup * revert sdpa changes * example cleared * add CP, DP to the mesh init * nit * clean * use `ALL_PARALLEL_STYLES` * style * FSDP works * log on 1 rank * . * fix? * FSDP1 also has .parameters() bug * reported gradnorm when using FSDP1 is wrong, but loss is correct so it's okay * . * style and fixup * move stuff around * fix tests * style * let's make it a check * warning should be an info --------- Co-authored-by: Arthur Zucker --- examples/3D_parallel.py | 422 ++++++++++ examples/pytorch/3d_parallel_checks.py | 780 ++++++++++++++++++ examples/pytorch/context_parallel.py | 94 +++ src/transformers/integrations/__init__.py | 4 +- .../integrations/tensor_parallel.py | 274 ++++-- src/transformers/modeling_utils.py | 75 +- 6 files changed, 1514 insertions(+), 135 deletions(-) create mode 100644 examples/3D_parallel.py create mode 100644 examples/pytorch/3d_parallel_checks.py create mode 100644 examples/pytorch/context_parallel.py diff --git a/examples/3D_parallel.py b/examples/3D_parallel.py new file mode 100644 index 00000000000..d56e63bc68f --- /dev/null +++ b/examples/3D_parallel.py @@ -0,0 +1,422 @@ +""": +This script is used to test training a model using Tensor Parallelism and Data Parallelism. 
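+It can also add Context Parallelism (CP) on top, building a 3-D (DP x TP x CP) device mesh: TP comes from
+`tp_plan="auto"`, DP from FSDP in NO_SHARD mode, and CP from `torch.distributed.tensor.experimental.context_parallel`.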
+ +Usage: +export CUDA_VISIBLE_DEVICES=0,1,2,3 +export CUDA_VISIBLE_DEVICES=4,5,6,7 +export CUDA_VISIBLE_DEVICES=5,6,7 +TP_SIZE=2 DP_SIZE=2 torchrun --nproc_per_node=4 --rdzv_endpoint=localhost:29503 examples/3D_parallel.py +CP_SIZE=2 DP_SIZE=2 torchrun --nproc_per_node=4 examples/3D_parallel.py +CP_SIZE=2 TP_SIZE=2 torchrun --nproc_per_node=4 examples/3D_parallel.py +DP_SIZE=2 CP_SIZE=2 TP_SIZE=2 torchrun --nproc_per_node=8 examples/3D_parallel.py + +TP_SIZE=1 CP_SIZE=4 torchrun --nproc_per_node=4 examples/3D_parallel.py +TP_SIZE=1 DP_SIZE=4 torchrun --nproc_per_node=4 examples/3D_parallel.py +TP_SIZE=4 DP_SIZE=1 torchrun --nproc_per_node=4 --rdzv_endpoint=localhost:29503 examples/3D_parallel.py +IGNORE_SANITY=1 CP_SIZE=1 TP_SIZE=1 DP_SIZE=1 torchrun --nproc_per_node=1 --rdzv_endpoint=localhost:29504 examples/3D_parallel.py +ocalhost:29504 test_train.py +""" + +import logging +import os +from contextlib import nullcontext +from typing import Iterable + +import torch +import torch.distributed as dist +import torch.distributed.checkpoint as dcp +import torch.optim as optim +import wandb +from datasets import load_dataset +from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict +from torch.distributed.checkpoint.stateful import Stateful +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import ShardingStrategy +from torch.distributed.tensor import DTensor +from torch.distributed.tensor.experimental import context_parallel +from torch.nn.attention import SDPBackend, sdpa_kernel +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler + +from transformers import AutoModelForCausalLM, AutoTokenizer + + +# torch.use_deterministic_algorithms(True) +torch.backends.cudnn.deterministic = True + +# Set up logging +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, +) +logger = logging.getLogger(__name__) + +# from torch.distributed.tensor.experimental._attention import set_rotate_method + +# set_rotate_method("alltoall") # CP rotate shards using all-to-all + + +def main(): + tp_size = int(os.environ.get("TP_SIZE", 1)) + dp_size = int(os.environ.get("DP_SIZE", 1)) + cp_size = int(os.environ.get("CP_SIZE", 1)) # Add CP size configuration + sdpa_backend = SDPBackend.FLASH_ATTENTION # For CP + # sdpa_backend = SDPBackend.MATH # For CP + global_batch_size = 8 # Desired global batch size + seq_len = 1024 # Sequence length + num_train_steps = 10000 # Number of training steps + LR = 1e-5 + model_name = "HuggingFaceTB/SmolLM2-1.7B" + # model_name = "unsloth/Llama-3.2-1B" + + CHECKPOINT_DIR = f"checkpoint_tp{tp_size}_dp{dp_size}_cp{cp_size}" + + # Initialize distributed environment + if "RANK" in os.environ and "WORLD_SIZE" in os.environ: + dist.init_process_group("nccl") + rank = dist.get_rank() + world_size = dist.get_world_size() + local_rank = int(os.environ["LOCAL_RANK"]) + torch.cuda.set_device(local_rank) + + assert world_size == tp_size * dp_size * cp_size, ( + f"World size ({world_size}) must equal TP size ({tp_size}) * DP size ({dp_size}) * CP size ({cp_size})" + ) + + mesh = torch.arange(world_size).reshape(dp_size, tp_size, cp_size) + world_mesh = DeviceMesh(device_type="cuda", mesh=mesh, mesh_dim_names=("dp", "tp", "cp")) + tp_mesh = world_mesh["tp"] + dp_mesh = world_mesh["dp"] + cp_mesh = world_mesh["cp"] + world_mesh["dp", 
"cp"]._flatten(mesh_dim_name="dp_cp") + logger.info(f"Created DeviceMesh: {world_mesh}") + logger.info( + f"Distributed setup - Rank: {rank}, World size: {world_size}, Local rank: {local_rank}, DP: {dp_mesh.get_local_rank()}, TP: {tp_mesh.get_local_rank()}, CP: {cp_mesh.get_local_rank()}" + ) + + if dist.get_rank() == 0: + wandb.init( + project="tp_dp_test", + config={ + "tp_size": tp_size, + "dp_size": dp_size, + "cp_size": cp_size, + "global_batch_size": global_batch_size, + "model_name": model_name, + "dataset": "roneneldan/TinyStories-1M", + "seq_len": seq_len, + "lr": LR, + "weight_decay": 0.1, + }, + name=f"llama_tp{tp_size}_dp{dp_size}_cp{cp_size}" + if model_name == "unsloth/Llama-3.2-1B" + else f"tp{tp_size}_dp{dp_size}_cp{cp_size}", + ) + logger.info("Wandb initialized.") + # Log the current file to wandb + wandb.save("test_train.py") + + # Load model and tokenizer + logger.info(f"Loading model and tokenizer from {model_name}") + tokenizer = AutoTokenizer.from_pretrained(model_name) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + logger.info(f"Set pad_token to eos_token: {tokenizer.pad_token}") + + model = AutoModelForCausalLM.from_pretrained( + model_name, + device_mesh=tp_mesh if dist.is_initialized() else None, + tp_plan="auto", + torch_dtype=torch.bfloat16, + ) + logger.info(f"Model loaded onto device mesh: {tp_mesh}") + device = torch.device(f"cuda:{local_rank}") + logger.info(f"Using device: {device} for non-model tensors") + use_ddp = False + if dist.is_initialized() and dp_mesh.size() > 1: + model = FSDP(model, device_mesh=dp_mesh, sharding_strategy=ShardingStrategy.NO_SHARD) + use_ddp = True + pass + + model.train() + + logger.info("Loading TinyStories dataset...") + raw_dataset = load_dataset("roneneldan/TinyStories", split="train[:1%]") # Use 1% for faster testing + + def tokenize_function(examples): + # Tokenize the text without padding + tokenized_batch = tokenizer( + examples["text"], padding=False, truncation=True, max_length=seq_len, return_tensors=None + ) + # Set labels to be the same as input_ids for Causal LM + tokenized_batch["labels"] = tokenized_batch["input_ids"].copy() + return tokenized_batch + + tokenized_dataset = raw_dataset.map(tokenize_function, batched=True, remove_columns=["text"]) + logger.info(f"Dataset loaded and tokenized. Size: {len(tokenized_dataset)}") + + # Create packed sequences + def create_packed_sequences(examples): + # Flatten all sequences + all_tokens = [] + for input_ids in examples["input_ids"]: + all_tokens.extend(input_ids) + + # Split into sequences of seq_len + 1 (for input + label) + num_sequences = len(all_tokens) // (seq_len + 1) + packed_input_ids = [] + packed_labels = [] + + for i in range(num_sequences): + start_idx = i * (seq_len + 1) + end_idx = start_idx + (seq_len + 1) + # Get the full sequence + full_sequence = all_tokens[start_idx:end_idx] + # For input_ids, remove the last token + packed_input_ids.append(full_sequence[:-1]) + # For labels, remove the first token + packed_labels.append(full_sequence[1:]) + + return {"input_ids": packed_input_ids, "labels": packed_labels} + + # Apply packing to the dataset + packed_dataset = tokenized_dataset.map( + create_packed_sequences, + batched=True, + remove_columns=tokenized_dataset.column_names, + batch_size=1000, # Process in batches for efficiency + num_proc=60, + ) + logger.info(f"Dataset packed. 
New size: {len(packed_dataset)}") + + # Shuffle the packed dataset + packed_dataset = packed_dataset.shuffle(seed=42) + logger.info("Packed dataset shuffled") + + # Calculate local batch size + if dist.is_initialized(): + assert global_batch_size % dp_mesh.size() == 0, ( + f"Global batch size ({global_batch_size}) must be divisible by DP size ({dp_mesh.size()})" + ) + local_batch_size = global_batch_size // dp_mesh.size() + else: + local_batch_size = global_batch_size + + logger.info( + f"Global batch size: {global_batch_size}, DP size: {dp_size if dist.is_initialized() else 1}, Local batch size: {local_batch_size}" + ) + + # Simple collate function since sequences are already packed + def collate_fn(batch): + input_ids = torch.tensor([item["input_ids"] for item in batch], dtype=torch.long) + labels = torch.tensor([item["labels"] for item in batch], dtype=torch.long) + return {"input_ids": input_ids, "labels": labels} + + if dist.is_initialized(): + sampler = DistributedSampler( + packed_dataset, num_replicas=dp_mesh.size(), rank=dp_mesh.get_local_rank(), shuffle=False + ) + else: + sampler = None + + dataloader = DataLoader( + packed_dataset, + batch_size=local_batch_size, + sampler=sampler, + shuffle=False, + collate_fn=collate_fn, + pin_memory=True, + ) + logger.info(f"DataLoader created. Distributed: {dist.is_initialized()}") + + optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=0.1) + + # Training loop + logger.info(f"Starting training for {num_train_steps} steps...") + model.train() + step = 0 + while step < num_train_steps: + for batch in dataloader: + if step >= num_train_steps: + break # Exit loop if max steps reached + + # Move batch to appropriate device + batch = {k: v.to(device) for k, v in batch.items()} + optimizer.zero_grad() + + # Add position_ids to batch before CP sharding + batch_size = batch["input_ids"].shape[0] + position_ids = torch.arange(0, seq_len, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).expand(batch_size, -1) + batch["position_ids"] = position_ids + from torch.distributed.tensor.experimental._attention import _cp_options + + _cp_options.enable_load_balance = False + + with sdpa_kernel(sdpa_backend): # TODO: ideally move this to attention implementation + cp_context = ( + nullcontext() + if cp_mesh.size() == 1 + else context_parallel( + cp_mesh, + buffers=[ + batch["input_ids"], + batch["labels"], + batch["position_ids"], + ], + buffer_seq_dims=[1, 1, 1], + ) + ) + with cp_context: + # Pop labels from batch before model forward pass + labels = batch.pop("labels") + outputs = model(**batch) # [mbs, seq_len/cp] + loss = outputs.loss + logits = outputs.logits + + # Compute loss with shifted labels + loss = model.loss_function( + logits=logits, labels=None, shift_labels=labels, vocab_size=model.config.vocab_size + ) + loss.backward() + + # all reduce grads across dp_cp if applicable + all_reduce_grads(model, world_mesh, use_ddp=use_ddp) + + if hasattr(model, "clip_grad_norm_"): + gradnorm = model.clip_grad_norm_(max_norm=1.0, norm_type=2.0) # TODO: fix reported gradnorm + else: + # only works with FSDP's NO_SHARD otherwise we should use FSDP's clip_grad_norm_ + assert len(list(model.parameters())) > 5, "No parameters found in model. Probably DDP bug.." 
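+                # NOTE: model.parameters() can come back empty when parameters are DTensors under FSDP1,
+                # in which case the clipping below would silently do nothing; this assert guards against that.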
+ gradnorm = clip_grad_norm_(model.parameters(), max_norm=1.0, norm_type=2.0, foreach=True) + + optimizer.step() + # allreduce loss across cp_dp before logging + if dist.is_initialized() and (cp_mesh.size() > 1 or dp_mesh.size() > 1): + dist.all_reduce(loss, group=world_mesh["dp_cp"].get_group(), op=dist.ReduceOp.AVG) + current_loss = loss.item() + + # Log loss and gradnorm to wandb (only on rank 0 of dp group) + if not dist.is_initialized() or dist.get_rank() == 0: + logger.info( + f"Step: {step} | GBS: {global_batch_size} | DP: {dp_mesh.size()} | TP: {tp_mesh.size()} | CP: {cp_mesh.size()} | Loss: {current_loss} | Gradnorm: {gradnorm} | lr: {LR}" + ) + wandb.log( + { + "train/loss": current_loss, + "train/gradnorm": gradnorm, + "step": step, + "lr": LR, + "GBS": global_batch_size, + } + ) + + step += 1 # Increment step count + + logger.info("Training loop finished.") + + # Save model using DCP (only if distributed) + if dist.is_initialized(): + state_dict = {"app": AppState(model, optimizer)} + dcp.save( + state_dict=state_dict, + checkpoint_id=CHECKPOINT_DIR, + ) + logger.info(f"Saved checkpoint to {CHECKPOINT_DIR}") + else: + # Fallback to regular save for non-distributed case + save_dir = "test_model_nondist" + model.save_pretrained(save_dir, safe_serialization=False) + tokenizer.save_pretrained(save_dir) # Save tokenizer too + logger.info(f"Saved model to {save_dir}") + + dist.destroy_process_group() + logger.info("Cleaned up distributed process group") + # Finish wandb run on rank 0 + if dist.get_rank() == 0: + wandb.finish() + logger.info("Wandb run finished.") + + +def all_reduce_grads(model, world_mesh, use_ddp): + """All reduce gradients across dp_cp if applicable.""" + cp_mesh = world_mesh["cp"] + if use_ddp: + # DDP/FSDP takes care of syncing grads + mesh = cp_mesh + else: + mesh = world_mesh["dp", "cp"]._flatten(mesh_dim_name="dp_cp") + if dist.is_initialized() and mesh.size() > 1: + for name, param in model.named_parameters(): + if param.grad is not None: + # Workaround for cross-mesh communication limitation with DTensor gradients + if isinstance(param.grad, DTensor): + local_grad = param.grad.to_local() + # Ensure grad requires grad for inplace modification checks (might not be needed) + # local_grad = local_grad.detach().requires_grad_(True) + torch.distributed.all_reduce(local_grad, op=torch.distributed.ReduceOp.SUM, group=mesh.get_group()) + local_grad = local_grad / mesh.size() + # Assign averaged grad back - need careful handling if DTensor structure is complex + # This simple assignment might work if the grad structure matches param structure + param.grad = DTensor.from_local( + local_grad, device_mesh=param.grad.device_mesh, placements=param.grad.placements + ) + else: + # Handle regular tensors if any exist (e.g. 
buffers not converted to DTensor) + torch.distributed.all_reduce(param.grad, op=torch.distributed.ReduceOp.AVG, group=mesh.get_group()) + + +class AppState(Stateful): + """Wrapper for checkpointing the Application State including model and optimizer.""" + + def __init__(self, model, optimizer=None): + self.model = model + self.optimizer = optimizer + + def state_dict(self): + model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer) + return {"model": model_state_dict, "optim": optimizer_state_dict} + + def load_state_dict(self, state_dict): + set_state_dict( + self.model, self.optimizer, model_state_dict=state_dict["model"], optim_state_dict=state_dict["optim"] + ) + + +def clip_grad_norm_( + parameters: Iterable[torch.Tensor], + max_norm: float, + norm_type: float = 2.0, + error_if_nonfinite: bool = False, + foreach: bool | None = None, +) -> torch.Tensor: + """ + Clip the gradient norm of an iterable of parameters. + """ + # Filter out parameters with no gradients + parameters = [p for p in parameters if p.grad is not None] + assert len(parameters) > 0, "No parameters with gradients found" + + # Calculate total norm + if norm_type == float("inf"): + total_norm = max(p.grad.detach().abs().max() for p in parameters) + else: + total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type) for p in parameters]), norm_type) + + # Convert DTensor to local tensor if needed + if isinstance(total_norm, DTensor): + total_norm = total_norm.full_tensor() + + # Clip gradients + clip_coef = max_norm / (total_norm + 1e-6) + if clip_coef < 1: + for p in parameters: + p.grad.detach().mul_(clip_coef) + + return total_norm + + +if __name__ == "__main__": + main() diff --git a/examples/pytorch/3d_parallel_checks.py b/examples/pytorch/3d_parallel_checks.py new file mode 100644 index 00000000000..1c7e88e5e4e --- /dev/null +++ b/examples/pytorch/3d_parallel_checks.py @@ -0,0 +1,780 @@ +""": +This script is used to test training a model using Tensor Parallelism and Data Parallelism. 
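+In addition to training, it runs sanity checks that parameters, gradients and logits are synchronized (or
+deliberately different) across the DP, TP and CP meshes, and verifies the saved DCP checkpoint by reloading
+it into a fresh model. Set IGNORE_SANITY=1 to skip the checks.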
+ +Usage: +export CUDA_VISIBLE_DEVICES=0,1,2,3 +export CUDA_VISIBLE_DEVICES=4,5,6,7 +export CUDA_VISIBLE_DEVICES=5,6,7 +TP_SIZE=2 DP_SIZE=2 torchrun --nproc_per_node=4 --rdzv_endpoint=localhost:29503 test_train.py +CP_SIZE=2 DP_SIZE=2 torchrun --nproc_per_node=4 test_train.py +CP_SIZE=2 TP_SIZE=2 torchrun --nproc_per_node=4 test_train.py + +TP_SIZE=1 CP_SIZE=4 torchrun --nproc_per_node=4 test_train.py +TP_SIZE=1 DP_SIZE=4 torchrun --nproc_per_node=4 test_train.py +TP_SIZE=4 DP_SIZE=1 torchrun --nproc_per_node=4 --rdzv_endpoint=localhost:29503 test_train.py +IGNORE_SANITY=1 CP_SIZE=1 TP_SIZE=1 DP_SIZE=1 torchrun --nproc_per_node=1 --rdzv_endpoint=l +ocalhost:29504 test_train.py +""" + +import logging +import os +from contextlib import nullcontext +from typing import Dict, Iterable, Optional + +import torch +import torch.distributed as dist +import torch.distributed.checkpoint as dcp +import torch.nn as nn +import torch.optim as optim +import wandb +from datasets import load_dataset +from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict +from torch.distributed.checkpoint.stateful import Stateful +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import ShardingStrategy +from torch.distributed.tensor import DTensor +from torch.distributed.tensor.experimental import context_parallel +from torch.nn.attention import SDPBackend, sdpa_kernel +from torch.utils.data import DataLoader, default_collate +from torch.utils.data.distributed import DistributedSampler + +from transformers import AutoModelForCausalLM, AutoTokenizer + + +ignore_sanity_checks = int(os.environ.get("IGNORE_SANITY", 0)) == 1 +# torch.use_deterministic_algorithms(True) +torch.backends.cudnn.deterministic = True + +# Set up logging +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, +) +logger = logging.getLogger(__name__) + +# from torch.distributed.tensor.experimental._attention import set_rotate_method + +# set_rotate_method("alltoall") # rotate shards using all-to-all + + +def main(): + tp_size = int(os.environ.get("TP_SIZE", 1)) + dp_size = int(os.environ.get("DP_SIZE", 4)) + cp_size = int(os.environ.get("CP_SIZE", 1)) # Add CP size configuration + sdpa_backend = SDPBackend.FLASH_ATTENTION # For CP + # sdpa_backend = SDPBackend.MATH # For CP + global_batch_size = 8 # Desired global batch size + seq_len = 1024 # Sequence length + num_train_steps = 10000 # Number of training steps + LR = 1e-5 + model_name = "HuggingFaceTB/SmolLM2-1.7B" + # model_name = "unsloth/Llama-3.2-1B" + + CHECKPOINT_DIR = f"checkpoint_tp{tp_size}_dp{dp_size}_cp{cp_size}" + + # Initialize distributed environment + if "RANK" in os.environ and "WORLD_SIZE" in os.environ: + dist.init_process_group("nccl") + rank = dist.get_rank() + world_size = dist.get_world_size() + local_rank = int(os.environ["LOCAL_RANK"]) + torch.cuda.set_device(local_rank) + + assert world_size == tp_size * dp_size * cp_size, ( + f"World size ({world_size}) must equal TP size ({tp_size}) * DP size ({dp_size}) * CP size ({cp_size})" + ) + + mesh = torch.arange(world_size).reshape(dp_size, tp_size, cp_size) + world_mesh = DeviceMesh(device_type="cuda", mesh=mesh, mesh_dim_names=("dp", "tp", "cp")) + tp_mesh = world_mesh["tp"] + dp_mesh = world_mesh["dp"] + cp_mesh = world_mesh["cp"] + world_mesh["dp", "cp"]._flatten(mesh_dim_name="dp_cp") + logger.info(f"Created DeviceMesh: 
{world_mesh}") + logger.info( + f"Distributed setup - Rank: {rank}, World size: {world_size}, Local rank: {local_rank}, DP: {dp_mesh.get_local_rank()}, TP: {tp_mesh.get_local_rank()}, CP: {cp_mesh.get_local_rank()}" + ) + + if dist.get_rank() == 0: + wandb.init( + project="tp_dp_test", + config={ + "tp_size": tp_size, + "dp_size": dp_size, + "cp_size": cp_size, + "global_batch_size": global_batch_size, + "model_name": model_name, + "dataset": "roneneldan/TinyStories-1M", + "seq_len": seq_len, + "lr": LR, + "weight_decay": 0.1, + }, + name=f"llama_tp{tp_size}_dp{dp_size}_cp{cp_size}" + if model_name == "unsloth/Llama-3.2-1B" + else f"tp{tp_size}_dp{dp_size}_cp{cp_size}", + ) + logger.info(f"ignore_sanity_checks is set to: {ignore_sanity_checks}") + logger.info("Wandb initialized.") + # Log the current file to wandb + wandb.save("test_train.py") + + else: + logger.info("Running in non-distributed mode. DeviceMesh not applicable.") + rank = 0 + world_size = 1 + local_rank = 0 + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + wandb.init( + project="tp_dp_test", + config={ + "tp_size": 1, + "dp_size": 1, + "global_batch_size": global_batch_size, + "model_name": model_name, + "dataset": "roneneldan/TinyStories-1M", + "seq_len": seq_len, + }, + name="llama_tp1_dp1_nondist" if model_name == "unsloth/Llama-3.2-1B" else "tp1_dp1_nondist", + ) + logger.info("Wandb initialized for non-distributed run.") + + # Load model and tokenizer + logger.info(f"Loading model and tokenizer from {model_name}") + tokenizer = AutoTokenizer.from_pretrained(model_name) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + logger.info(f"Set pad_token to eos_token: {tokenizer.pad_token}") + + model = AutoModelForCausalLM.from_pretrained( + model_name, + device_mesh=tp_mesh if dist.is_initialized() else None, + tp_plan="auto", + torch_dtype=torch.bfloat16, + ) + logger.info(f"Model loaded onto device mesh: {tp_mesh}") + + if dist.is_initialized(): + assert model.config.num_key_value_heads % tp_mesh.size() == 0, ( + f"num_key_value_heads={model.config.num_key_value_heads} must be divisible by tp_size={tp_mesh.size()}" + ) + device = torch.device(f"cuda:{local_rank}") + else: + model = model.to(device) + + logger.info(f"Using device: {device} for non-model tensors") + use_ddp = False + if dist.is_initialized() and dp_mesh.size() > 1: + # FSDP1 + model = FSDP(model, device_mesh=dp_mesh, sharding_strategy=ShardingStrategy.NO_SHARD) + # FSDP2 + # for transformer_block in model.model.layers: + # fully_shard(transformer_block, mesh=dp_mesh, reshard_after_forward=False) + # fully_shard(model.model, mesh=dp_mesh, reshard_after_forward=False) + # DDP + # replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100) + # assert len(list(model.parameters()))>5, "No parameters found in model. Probably DDP/FSDP bug.." # TODO: we should be cautious abt using model.parameters() + use_ddp = True + + model.train() + assert len(list(model.parameters())) > 0, "No parameters found in model. Probably DDP bug.." + assert len([p for p in model.parameters() if p.requires_grad]) > 0, ( + "No gradients found in model. Probably DDP bug.." 
+ ) + + if dist.is_initialized() and not ignore_sanity_checks: + # assert model is replicated across all dp + for name, param in model.named_parameters(): + sanity_check_tensor_sync(param, dp_mesh) + + # assert model is different across tp (only for sharded params) + for name, param in model.named_parameters(): + if isinstance(param, DTensor) and param.placements[0].is_shard(): + # Only check sharded parameters for non-sync across TP + sanity_check_tensor_sync(param, tp_mesh, not_sync=True) + elif isinstance(param, DTensor) and param.placements[0].is_replicate(): + # Replicated parameters should be the same across TP + sanity_check_tensor_sync(param, tp_mesh) + + # assert model is replicated across cp + for name, param in model.named_parameters(): + sanity_check_tensor_sync(param, cp_mesh) + + # Load and preprocess TinyStories dataset + logger.info("Loading TinyStories dataset...") + raw_dataset = load_dataset("roneneldan/TinyStories", split="train[:1%]") # Use 1% for faster testing + + def tokenize_function(examples): + # Tokenize the text without padding + tokenized_batch = tokenizer( + examples["text"], padding=False, truncation=True, max_length=seq_len, return_tensors=None + ) + # Set labels to be the same as input_ids for Causal LM + tokenized_batch["labels"] = tokenized_batch["input_ids"].copy() + return tokenized_batch + + tokenized_dataset = raw_dataset.map(tokenize_function, batched=True, remove_columns=["text"]) + logger.info(f"Dataset loaded and tokenized. Size: {len(tokenized_dataset)}") + + # Create packed sequences + def create_packed_sequences(examples): + # Flatten all sequences + all_tokens = [] + for input_ids in examples["input_ids"]: + all_tokens.extend(input_ids) + + # Split into sequences of seq_len + 1 (for input + label) + num_sequences = len(all_tokens) // (seq_len + 1) + packed_input_ids = [] + packed_labels = [] + + for i in range(num_sequences): + start_idx = i * (seq_len + 1) + end_idx = start_idx + (seq_len + 1) + # Get the full sequence + full_sequence = all_tokens[start_idx:end_idx] + # For input_ids, remove the last token + packed_input_ids.append(full_sequence[:-1]) + # For labels, remove the first token + packed_labels.append(full_sequence[1:]) + + return {"input_ids": packed_input_ids, "labels": packed_labels} + + # Apply packing to the dataset + packed_dataset = tokenized_dataset.map( + create_packed_sequences, + batched=True, + remove_columns=tokenized_dataset.column_names, + batch_size=1000, # Process in batches for efficiency + num_proc=60, + ) + logger.info(f"Dataset packed. 
New size: {len(packed_dataset)}") + + # Shuffle the packed dataset + packed_dataset = packed_dataset.shuffle(seed=42) + logger.info("Packed dataset shuffled") + + # Calculate local batch size + if dist.is_initialized(): + assert global_batch_size % dp_mesh.size() == 0, ( + f"Global batch size ({global_batch_size}) must be divisible by DP size ({dp_mesh.size()})" + ) + local_batch_size = global_batch_size // dp_mesh.size() + else: + local_batch_size = global_batch_size + + logger.info( + f"Global batch size: {global_batch_size}, DP size: {dp_size if dist.is_initialized() else 1}, Local batch size: {local_batch_size}" + ) + + # Simple collate function since sequences are already packed + def collate_fn(batch): + input_ids = torch.tensor([item["input_ids"] for item in batch], dtype=torch.long) + labels = torch.tensor([item["labels"] for item in batch], dtype=torch.long) + return {"input_ids": input_ids, "labels": labels} + + if dist.is_initialized(): + sampler = DistributedSampler( + packed_dataset, num_replicas=dp_mesh.size(), rank=dp_mesh.get_local_rank(), shuffle=False + ) + else: + sampler = None + + dataloader = DataLoader( + packed_dataset, + batch_size=local_batch_size, + sampler=sampler, + shuffle=False, + collate_fn=collate_fn, + ) + logger.info(f"DataLoader created. Distributed: {dist.is_initialized()}") + + optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=0.1) + + # Training loop + logger.info(f"Starting training for {num_train_steps} steps...") + model.train() + step = 0 + while step < num_train_steps: + for batch in dataloader: + if step >= num_train_steps: + break # Exit loop if max steps reached + + # Move batch to appropriate device + batch = {k: v.to(device) for k, v in batch.items()} + + # Sanity checks for batch distribution (only if distributed) + if dist.is_initialized() and not ignore_sanity_checks: + # check batch is same across all tp + sanity_check_tensor_sync(batch["input_ids"], tp_mesh) + # check batch is different across dp + sanity_check_tensor_sync(batch["input_ids"], dp_mesh, not_sync=True) + + optimizer.zero_grad() + + # Add position_ids to batch before CP sharding + batch_size = batch["input_ids"].shape[0] + position_ids = torch.arange(0, seq_len, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).expand(batch_size, -1) + batch["position_ids"] = position_ids + from torch.distributed.tensor.experimental._attention import _cp_options + + _cp_options.enable_load_balance = False + + with sdpa_kernel(sdpa_backend): # TODO: ideally move this to attention implementation + cp_context = ( + nullcontext() + if cp_mesh.size() == 1 + else context_parallel( + cp_mesh, + buffers=[ + batch["input_ids"], + batch["labels"], + batch["position_ids"], + ], # TODO: need to add attention mask + buffer_seq_dims=[1, 1, 1], + ) + ) + with cp_context: + # Pop labels from batch before model forward pass + labels = batch.pop("labels") + outputs = model(**batch) # [mbs, seq_len/cp] + loss = outputs.loss + logits = outputs.logits + + # Compute loss with shifted labels + loss = model.loss_function( + logits=logits, labels=None, shift_labels=labels, vocab_size=model.config.vocab_size + ) + + # Sanity checks for logits + if dist.is_initialized() and not ignore_sanity_checks: + # sanity_check_tensor_sync(logits, tp_mesh) # TODO: only true without sequence parallel + sanity_check_tensor_sync(logits, dp_mesh, not_sync=True) + sanity_check_tensor_sync(logits, cp_mesh, not_sync=True) + + loss.backward() + + # all reduce grads across dp_cp if applicable + 
all_reduce_grads(model, world_mesh, use_ddp=use_ddp) + + # Sanity checks for gradients (only if distributed) + if dist.is_initialized() and not ignore_sanity_checks: + # check grads are not same across all tp (for sharded grads) + for name, param in model.named_parameters(): + if param.grad is not None and isinstance(param.grad, DTensor): + if param.grad.placements[0].is_shard(): + sanity_check_tensor_sync(param.grad, tp_mesh, not_sync=True) + elif param.grad.placements[0].is_replicate(): + sanity_check_tensor_sync(param.grad, tp_mesh) + # check grads are same across dp + for name, param in model.named_parameters(): + if param.grad is not None and dp_mesh.size() > 1: + sanity_check_tensor_sync(param.grad, dp_mesh) + # check grads are same across cp + for name, param in model.named_parameters(): + if param.grad is not None and cp_mesh.size() > 1: + sanity_check_tensor_sync(param.grad, cp_mesh) + + # Calculate gradient norm and clip gradients + if hasattr(model, "clip_grad_norm_"): + # when using FSDP or DDP, model.parameters() doesn't work + gradnorm = model.clip_grad_norm_(max_norm=1.0, norm_type=2.0) + else: + assert len(list(model.parameters())) > 2, "No parameters found in model. Probably DDP bug.." + assert len([p for p in model.parameters() if p.requires_grad]) > 2, ( + "No gradients found in model. Probably DDP bug.." + ) + assert len([p for p in model.parameters() if p.grad is not None]) > 2, ( + "No gradients found in model. Probably DDP bug.." + ) + # only works with FSDP's NO_SHARD otherwise we should use FSDP's clip_grad_norm_ + gradnorm = clip_grad_norm_(model.parameters(), max_norm=1.0, norm_type=2.0, foreach=True) + + optimizer.step() + # Sanity checks for updated model parameters (only if distributed) + if dist.is_initialized() and not ignore_sanity_checks: + # check updated model is different across all tp (for sharded params) + for name, param in model.named_parameters(): + if isinstance(param, DTensor): + if param.placements[0].is_shard(): + sanity_check_tensor_sync(param, tp_mesh, not_sync=True) + elif param.placements[0].is_replicate(): + sanity_check_tensor_sync(param, tp_mesh) + # check updated model is same across dp + for name, param in model.named_parameters(): + sanity_check_tensor_sync(param, dp_mesh) + # check updated model is same across cp + for name, param in model.named_parameters(): + sanity_check_tensor_sync(param, cp_mesh) + + # allreduce loss across cp_dp before logging + if dist.is_initialized() and (cp_mesh.size() > 1 or dp_mesh.size() > 1): + dist.all_reduce(loss, group=world_mesh["dp_cp"].get_group(), op=dist.ReduceOp.AVG) + current_loss = loss.item() + + # Log loss and gradnorm to wandb (only on rank 0 of dp group) + if not dist.is_initialized() or dist.get_rank() == 0: + logger.info( + f"Step: {step} | GBS: {global_batch_size} | DP: {dp_mesh.size()} | TP: {tp_mesh.size()} | CP: {cp_mesh.size()} | Loss: {current_loss} | Gradnorm: {gradnorm} | lr: {LR}" + ) + wandb.log( + { + "train/loss": current_loss, + "train/gradnorm": gradnorm, + "step": step, + "lr": LR, + "GBS": global_batch_size, + } + ) + + step += 1 # Increment step count + + logger.info("Training loop finished.") + + # Save model using DCP (only if distributed) + if dist.is_initialized(): + state_dict = {"app": AppState(model, optimizer)} + dcp.save( + state_dict=state_dict, + checkpoint_id=CHECKPOINT_DIR, + ) + logger.info(f"Saved checkpoint to {CHECKPOINT_DIR}") + else: + # Fallback to regular save for non-distributed case + save_dir = "test_model_nondist" + 
model.save_pretrained(save_dir, safe_serialization=False) + tokenizer.save_pretrained(save_dir) # Save tokenizer too + logger.info(f"Saved model to {save_dir}") + + # Example of loading the checkpoint (only if distributed) + if dist.is_initialized(): + # Create a new model instance + logger.info("Creating new model instance for verification") + new_model = AutoModelForCausalLM.from_pretrained( + model_name, + device_mesh=tp_mesh, + torch_dtype=torch.bfloat16, # Use same dtype + ) + new_optimizer = optim.AdamW(new_model.parameters(), lr=LR) + + # Load checkpoint into new model + state_dict = {"app": AppState(new_model, new_optimizer)} + dcp.load( + state_dict=state_dict, + checkpoint_id=CHECKPOINT_DIR, + ) + logger.info("Loaded checkpoint into new model") + + # Verify model weights match + logger.info("Verifying model weights match...") + for (name1, param1), (name2, param2) in zip(model.named_parameters(), new_model.named_parameters()): + torch.testing.assert_close( + param1.to_local(), + param2.to_local(), + rtol=1e-3, + atol=1e-3, + msg=f"Weights mismatch in {name1} vs {name2}", + ) + + # Verify optimizer states match + logger.info("Verifying optimizer states match...") + for name1, state1 in optimizer.state_dict().items(): + state2 = new_optimizer.state_dict()[name1] + if name1 == "state": + # Compare state dictionaries for each parameter + for param_id, param_state1 in state1.items(): + param_state2 = state2[param_id] + # Compare each state component (step, exp_avg, exp_avg_sq) + for key, value1 in param_state1.items(): + value2 = param_state2[key] + if isinstance(value1, DTensor): + # Convert DTensors to local tensors for comparison + torch.testing.assert_close( + value1.to_local(), + value2.to_local(), + rtol=1e-5, + atol=1e-5, + msg=f"Optimizer state mismatch in state[{param_id}][{key}]", + ) + else: + torch.testing.assert_close( + value1, + value2, + rtol=1e-5, + atol=1e-5, + msg=f"Optimizer state mismatch in state[{param_id}][{key}]", + ) + elif name1 == "param_groups": + # Compare param_groups (excluding the actual params list) + for i, (group1, group2) in enumerate(zip(state1, state2)): + for key in group1: + if key != "params": # Skip comparing the params list + assert group1[key] == group2[key], f"Param group mismatch in param_groups[{i}][{key}]" + + # Run a forward pass with both models to verify outputs match + logger.info("Running forward pass verification...") + with torch.no_grad(): + # Use the last batch for verification + batch = {k: v.to(device) for k, v in batch.items()} # Ensure batch is on correct device + original_outputs = model(**batch) + new_outputs = new_model(**batch) + torch.testing.assert_close( + original_outputs.logits.to_local(), + new_outputs.logits.to_local(), + rtol=1e-3, + atol=1e-3, + msg="Model outputs do not match!", + ) # Increased tolerance slightly for bf16 + + # Clean up distributed environment and finish wandb run + if dist.is_initialized(): + dist.destroy_process_group() + logger.info("Cleaned up distributed process group") + # Finish wandb run on rank 0 + if dist.get_rank() == 0: + wandb.finish() + logger.info("Wandb run finished.") + else: + wandb.finish() + logger.info("Wandb run finished.") + + +def all_reduce_grads(model, world_mesh, use_ddp): + """All reduce gradients across dp_cp if applicable.""" + cp_mesh = world_mesh["cp"] + if use_ddp: + # DDP takes care of syncing grads + mesh = cp_mesh + else: + mesh = world_mesh["dp", "cp"]._flatten(mesh_dim_name="dp_cp") + if dist.is_initialized() and mesh.size() > 1: + for name, param in 
model.named_parameters(): + if param.grad is not None: + # Workaround for cross-mesh communication limitation with DTensor gradients + if isinstance(param.grad, DTensor): + local_grad = param.grad.to_local() + # Ensure grad requires grad for inplace modification checks (might not be needed) + # local_grad = local_grad.detach().requires_grad_(True) + torch.distributed.all_reduce(local_grad, op=torch.distributed.ReduceOp.SUM, group=mesh.get_group()) + local_grad = local_grad / mesh.size() + # Assign averaged grad back - need careful handling if DTensor structure is complex + # This simple assignment might work if the grad structure matches param structure + param.grad = DTensor.from_local( + local_grad, device_mesh=param.grad.device_mesh, placements=param.grad.placements + ) + else: + # Handle regular tensors if any exist (e.g. buffers not converted to DTensor) + torch.distributed.all_reduce(param.grad, op=torch.distributed.ReduceOp.AVG, group=mesh.get_group()) + + +class ContextParallelCollator: + """Collator for context parallel training that splits sequences into chunks.""" + + def __init__(self, cp_mesh: Optional[DeviceMesh] = None): + self.cp_mesh = cp_mesh + + def __call__(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + batch = default_collate(batch) + if self.cp_mesh is not None and self.cp_mesh.size() > 1: + # Get sequence length from the input batch + seq_len = batch["input_ids"].shape[1] + assert seq_len % self.cp_mesh.size() == 0, ( + f"Sequence length {seq_len} must be divisible by CP size {self.cp_mesh.size()}" + ) + chunk_size = seq_len // self.cp_mesh.size() + cp_rank = self.cp_mesh.get_local_rank() + start_idx = cp_rank * chunk_size + end_idx = start_idx + chunk_size + + # Keep only the local chunk of the sequence + batch["input_ids"] = batch["input_ids"][:, start_idx:end_idx] + batch["attention_mask"] = batch["attention_mask"][:, start_idx:end_idx] + batch["labels"] = batch["labels"][:, start_idx:end_idx] + + return batch + + +class AppState(Stateful): + """Wrapper for checkpointing the Application State including model and optimizer.""" + + def __init__(self, model, optimizer=None): + self.model = model + self.optimizer = optimizer + + def state_dict(self): + model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer) + return {"model": model_state_dict, "optim": optimizer_state_dict} + + def load_state_dict(self, state_dict): + set_state_dict( + self.model, self.optimizer, model_state_dict=state_dict["model"], optim_state_dict=state_dict["optim"] + ) + + +def sanity_check_tensor_sync( + tensor: torch.Tensor, mesh: DeviceMesh, rtol: float = 1e-4, atol: float = 1e-4, not_sync: bool = False +) -> None: + """ + Verify that a tensor is synchronized (or not synchronized) across all processes in the mesh's process group. + Handles both regular tensors and DTensors. + + Args: + tensor (torch.Tensor): The tensor to check for synchronization (can be DTensor) + mesh (DeviceMesh): The device mesh containing the process group + rtol (float): Relative tolerance for comparison + atol (float): Absolute tolerance for comparison + not_sync (bool): If True, asserts that tensors are NOT synchronized. If False, asserts they are synchronized. 
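+
+    Example (sketch, assuming `param`, `dp_mesh` and `tp_mesh` exist as in `main()`):
+        # a replicated parameter must be identical on every DP rank
+        sanity_check_tensor_sync(param, dp_mesh)
+        # a TP-sharded parameter is expected to differ across TP ranks
+        sanity_check_tensor_sync(param, tp_mesh, not_sync=True)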
+ """ + if not dist.is_initialized() or mesh.size() == 1: + return # No need to check in non-distributed mode + + # Get the process group from the mesh + pg = mesh.get_group() + + # Convert DTensor to local tensor if needed + if hasattr(tensor, "to_local"): + local_tensor = tensor.to_local() + else: + local_tensor = tensor + + # Gather tensors from all processes + world_size = dist.get_world_size(pg) + gathered_tensors = [torch.empty_like(local_tensor) for _ in range(world_size)] + dist.all_gather(gathered_tensors, local_tensor, group=pg) + + # Compare each tensor with the first one + for i in range(1, world_size): + try: + torch.testing.assert_close(gathered_tensors[0], gathered_tensors[i], rtol=rtol, atol=atol) + except AssertionError as e: + if not_sync: + continue + # # Add detailed debugging for logit synchronization issues + # print(f"\nLogit synchronization error between rank 0 and rank {i}:") + # print(f"Tensor shape: {gathered_tensors[0].shape}") + # print(f"Number of mismatched elements: {(gathered_tensors[0] != gathered_tensors[i]).sum()}") + # print(f"Percentage of mismatched elements: {((gathered_tensors[0] != gathered_tensors[i]).sum() / gathered_tensors[0].numel() * 100):.2f}%") + + # # Find the first few mismatches + # mismatches = torch.nonzero(gathered_tensors[0] != gathered_tensors[i]) + # print("\nFirst few mismatches:") + # for idx in mismatches[:5]: + # idx = tuple(idx.tolist()) + # print(f"Index {idx}:") + # print(f"Rank 0 value: {gathered_tensors[0][idx]}") + # print(f"Rank {i} value: {gathered_tensors[i][idx]}") + # print(f"Absolute difference: {abs(gathered_tensors[0][idx] - gathered_tensors[i][idx])}") + # print(f"Relative difference: {abs(gathered_tensors[0][idx] - gathered_tensors[i][idx]) / max(abs(gathered_tensors[0][idx]), abs(gathered_tensors[i][idx]))}") + + # # Check if differences are systematic (e.g., all positive or negative) + # diff = gathered_tensors[0] - gathered_tensors[i] + # print(f"\nDifference statistics:") + # print(f"Mean difference: {diff.mean()}") + # print(f"Std difference: {diff.std()}") + # print(f"Max positive difference: {diff.max()}") + # print(f"Max negative difference: {diff.min()}") + raise e + + +def clip_grad_norm_( + parameters: Iterable[torch.Tensor], + max_norm: float, + norm_type: float = 2.0, + error_if_nonfinite: bool = False, + foreach: bool | None = None, +) -> torch.Tensor: + """ + Clip the gradient norm of an iterable of parameters. + """ + # Filter out parameters with no gradients + parameters = [p for p in parameters if p.grad is not None] + assert len(parameters) > 0, "No parameters with gradients found" + + # Calculate total norm + if norm_type == float("inf"): + total_norm = max(p.grad.detach().abs().max() for p in parameters) + else: + total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type) for p in parameters]), norm_type) + + # Convert DTensor to local tensor if needed + if isinstance(total_norm, DTensor): + total_norm = total_norm.full_tensor() + + # Clip gradients + clip_coef = max_norm / (total_norm + 1e-6) + if clip_coef < 1: + for p in parameters: + p.grad.detach().mul_(clip_coef) + + return total_norm + + +def check_params_sync(model_params, original_params): + """ + Check if original_params are being updated in sync with model parameters. 
+ + Args: + model_params: Iterator of model parameters after update + original_params: List of original parameters before DDP wrapping + """ + for mp, op in zip(model_params, original_params): + if isinstance(mp, DTensor): + mp = mp.to_local() + if isinstance(op, DTensor): + op = op.to_local() + if not torch.allclose(mp.data, op.data, rtol=0, atol=0): + raise RuntimeError(f"Parameters out of sync: model param {mp.data} != original param {op.data}") + return True + + +def get_parameters(model: nn.Module) -> Iterable[torch.Tensor]: + """ + Get all parameters from a model by iterating over its modules. + This is an alternative to model.parameters() that works with DTensor models. + + Args: + model (nn.Module): The model to get parameters from + + Returns: + Iterable[torch.Tensor]: An iterator over all parameters in the model + """ + for name, module in model._modules.items(): + # Look for parameters in module attributes + for attr_name, attr in module.__dict__.items(): + if isinstance(attr, torch.Tensor) and attr.requires_grad: + yield attr + # Recursively get parameters from submodules + for param in get_parameters(module): + yield param + + +def update_model_parameters(model: nn.Module) -> None: + """ + Update model._parameters using named_modules() to ensure all parameters are properly tracked. + + Args: + model (nn.Module): The model to update parameters for + """ + # Clear existing parameters + model._parameters = {} + + # Add parameters from named_modules + for name, module in model.named_modules(): + # Skip the root module itself + if name == "": + continue + + # Get the parameter name by removing 'module.' prefix if it exists + param_name = name.replace("module.", "") + + # Add weight and bias parameters if they exist + if hasattr(module, "weight") and module.weight is not None: + model._parameters[f"{param_name}.weight"] = module.weight + if hasattr(module, "bias") and module.bias is not None: + model._parameters[f"{param_name}.bias"] = module.bias + + +if __name__ == "__main__": + main() diff --git a/examples/pytorch/context_parallel.py b/examples/pytorch/context_parallel.py new file mode 100644 index 00000000000..22cc75b20f0 --- /dev/null +++ b/examples/pytorch/context_parallel.py @@ -0,0 +1,94 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
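+
+# Minimal Context Parallel (CP) training-step example.
+# Launch with one process per GPU, e.g. (assuming 2 GPUs are available):
+#   torchrun --nproc_per_node=2 examples/pytorch/context_parallel.py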
+import os + +import torch +import torch.distributed as dist +from torch.distributed.device_mesh import init_device_mesh +from torch.distributed.tensor.experimental import context_parallel +from torch.nn.attention import SDPBackend, sdpa_kernel +from torch.nn.parallel import DistributedDataParallel as DDP + +from transformers import AutoModelForCausalLM +from transformers.loss.loss_utils import ForCausalLMLoss + + +world_size = int(os.environ.get("WORLD_SIZE", "1")) +cp_mesh = init_device_mesh("cuda", (world_size,)) +rank = torch.distributed.get_node_local_rank() + +device = "cuda" +dtype = torch.bfloat16 +sdpa_backend = SDPBackend.FLASH_ATTENTION + +# prepare inputs +batch_size = 1 +seq_len = 128 + +input_ids = torch.randint(low=8, high=64, size=(batch_size, seq_len), device=device) + +ignore_index = -100 +# When using CP, we need to use `shift_labels` +shift_labels = torch.nn.functional.pad(input_ids, (0, 1), value=ignore_index) +shift_labels = shift_labels[..., 1:].contiguous() + +position_ids = ( + torch.cumsum(torch.ones(size=input_ids.size(), dtype=input_ids.dtype, device=input_ids.device), dim=1) - 1 +) + +# sync input as they are created randomly +dist.broadcast(input_ids, src=0) +dist.broadcast(shift_labels, src=0) +dist.broadcast(position_ids, src=0) + +# model and optimizer +repo_id = "Qwen/Qwen2.5-Coder-0.5B-Instruct" +model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype=dtype, device_map=device) +optimizer = torch.optim.Adam(model.parameters(), lr=1e-5) + +model.train() +model.zero_grad() +optimizer.zero_grad() + +# For loss +vocab_size = model.config.vocab_size + +# so training could be synced +model = DDP(model, device_ids=[rank]) + +# prepare for CP +buffers = (input_ids, shift_labels, position_ids) +buffer_seq_dims = (1, 1, 1) +# `no_restore_buffers=set(buffers)` is required if `loss.backward` is outside `context_parallel`. 
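+# (buffers in that set are simply left in their sharded state when the context exits, instead of being restored.)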
+# no_restore_buffers = set(buffers) +no_restore_buffers = None + +# run with CP +with sdpa_kernel(sdpa_backend): + with context_parallel( + cp_mesh, + buffers=buffers, + buffer_seq_dims=buffer_seq_dims, + no_restore_buffers=no_restore_buffers, + ): + outputs = model(input_ids, shift_labels=shift_labels, position_ids=position_ids) + print(outputs.logits.shape) + + # So far we need to compute `loss` outside `model.forward` when using `shift_labels` + # loss = outputs.loss + loss = ForCausalLMLoss(logits=outputs.logits, labels=None, shift_labels=shift_labels, vocab_size=vocab_size) + + # This could be outside `context_parallel` context if `no_restore_buffers` is specified + loss.backward() + optimizer.step() diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py index 8d03c5cf790..1b87a554d3c 100755 --- a/src/transformers/integrations/__init__.py +++ b/src/transformers/integrations/__init__.py @@ -142,7 +142,7 @@ except OptionalDependencyNotAvailable: else: _import_structure["tensor_parallel"] = [ "shard_and_distribute_module", - "SUPPORTED_TP_STYLES", + "ALL_PARALLEL_STYLES", "translate_to_torch_parallel_style", ] try: @@ -271,7 +271,7 @@ if TYPE_CHECKING: pass else: from .tensor_parallel import ( - SUPPORTED_TP_STYLES, + ALL_PARALLEL_STYLES, shard_and_distribute_module, translate_to_torch_parallel_style, ) diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index e788321b49b..ef868148b87 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -13,11 +13,15 @@ # limitations under the License. from __future__ import annotations +import operator +import os import re -from functools import lru_cache, partial -from typing import List, Optional, Tuple, Union +from collections.abc import MutableMapping +from functools import partial, reduce +from typing import Callable, List, Optional, Tuple, Union import torch +import torch.distributed as dist from torch import nn from ..utils import is_torch_greater_or_equal, logging @@ -35,6 +39,56 @@ if is_torch_greater_or_equal("2.5") and _torch_distributed_available: from torch.distributed.tensor import DTensor, Placement, Replicate, Shard +def initialize_tensor_parallelism(tp_plan, tp_size=None): + r""" + Sets up the device mesh and initilized the backend for tensor parallelism. + This function is called when the model is loaded and the TP plan is set to 'auto'. + """ + if tp_plan is None: + return None, None, None + + if not is_torch_greater_or_equal("2.5"): + raise EnvironmentError("Tensor parallel is only supported for `torch>=2.5`.") + + # Detect the accelerator on the machine. If no accelerator is available, it returns CPU. + device_type = torch._C._get_accelerator().type + if not torch.distributed.is_initialized(): + try: + rank = int(os.environ["RANK"]) + local_rank = int(os.environ["LOCAL_RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + + backend_map = {"cuda": "nccl", "cpu": "gloo", "xpu": "ccl", "hpu": "hccl"} + backend = backend_map.get(device_type) + if device_type == "cpu" and int(os.environ.get("CCL_WORKER_COUNT", 0)): + backend = "ccl" + + torch.distributed.init_process_group(backend=backend, rank=rank, world_size=world_size) + current_device = getattr(torch, device_type) + if device_type != "cpu": + current_device.set_device(local_rank) + + except Exception as e: + raise EnvironmentError( + "We tried to initialize torch.distributed for you, but it failed. 
Make " + "sure you init torch distributed in your script to use `tp_plan='auto'`." + ) from e + index = current_device.current_device() if device_type != "cpu" else None + tp_device = torch.device(device_type, index) + + # Silence output for non-primary ranks + if index is not None and index > 0: + import sys + + sys.stdout = open(os.devnull, "w") + sys.stderr = open(os.devnull, "w") + + device_map = tp_device + tp_size = tp_size if tp_size is not None else torch.distributed.get_world_size() + device_mesh = torch.distributed.init_device_mesh(tp_device.type, (tp_size,)) + return tp_device, device_map, device_mesh + + def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]: """ Convert block count or proportions to block sizes. @@ -220,18 +274,38 @@ def repack_weights( def get_tensor_shard(param, empty_param, device_mesh, rank, dim): - if dim == 0: - size_ = empty_param.shape[0] - param = param[rank * (size_ // device_mesh.size()) : (rank + 1) * (size_ // device_mesh.size()), ...] - elif dim == 1 or dim == -2: - size_ = empty_param.shape[-2] - param = param[..., rank * (size_ // device_mesh.size()) : (rank + 1) * (size_ // device_mesh.size()), :] - elif dim == 2 or dim == -1: - size_ = empty_param.shape[-1] - param = param[..., rank * (size_ // device_mesh.size()) : (rank + 1) * (size_ // device_mesh.size())] - else: - raise ValueError(f"Unsupported dim {dim}, only dim 0, 1 or 2 are supported") - return param + """ + Generalized tensor sharding across a multi-dimensional device mesh. + + Args: + param (torch.Tensor): The tensor to shard. + empty_param (torch.Tensor): A tensor used for shape reference. + device_mesh (torch.Tensor): Shape [d_0, ..., d_n] representing the mesh. + rank (int): Global rank of the current process/device. + dim (int): Dimension along which to shard the tensor. 
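+
+    Example (sketch, assuming a 4-device mesh):
+        # a [8, 16] weight sharded along dim 0 gives each rank a [2, 16] slice
+        shard = get_tensor_shard(param, empty_param, device_mesh, rank, 0)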
+ """ + param_dim = empty_param.dim() + if dim < 0: + dim = param_dim + dim + if dim >= param_dim: + raise ValueError(f"dim {dim} is out of bounds for tensor of dimension {param_dim}") + + # Flatten the mesh to get the total number of devices + mesh_shape = device_mesh.shape + world_size = reduce(operator.mul, mesh_shape) + + if rank >= world_size: + raise ValueError(f"Rank {rank} is out of bounds for mesh size {world_size}") + + shard_size = empty_param.shape[dim] // world_size + start = rank * shard_size + end = start + shard_size + + # Construct slicing index dynamically + slice_indices = [slice(None)] * param_dim + slice_indices[dim] = slice(start, end) + + return param[tuple(slice_indices)] def distribute_module( @@ -339,6 +413,41 @@ class IsolatedParallel(TensorParallelLayer): ) +class ReplicateParallel(TensorParallelLayer): + """ + This class is used to replicate computation in a TP layer (used in SP regions when we don't use sequence parallelism for example) + """ + + def __init__(self, *, use_dtensor=True, use_local_output=True): + super().__init__() + self.input_layouts = (Replicate(),) + self.output_layouts = (Replicate(),) + self.desired_input_layouts = (Replicate(),) + self.use_local_output = use_local_output + self.use_dtensor = use_dtensor + + @staticmethod + def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh): + # TODO: figure out dynamo support for instance method and switch this to instance method + # annotate module input placements/sharding with input_layouts + input_tensor = inputs[0] + if not isinstance(input_tensor, DTensor): + input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False) + + return input_tensor + + @staticmethod + def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): + return outputs.to_local() if use_local_output else outputs + + def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): + param = param[...].to(param_casting_dtype) + if to_contiguous: + param = param.contiguous() + param = DTensor.from_local(param, device_mesh, [Replicate()], run_check=False) + return param + + class ColwiseParallel(TensorParallelLayer): """ General tensor parallel layer for transformers. @@ -611,52 +720,62 @@ class SequenceParallel(TensorParallelLayer): return nn.Parameter(parameter, requires_grad=parameter.is_floating_point()) -SUPPORTED_TP_STYLES = { - "colwise", - "rowwise", - "colwise_rep", - "rowwise_rep", - "local_colwise", - "local_rowwise", - "local", - "gather", - "local_packed_rowwise", - "sequence_parallel", -} - - -@lru_cache -def translate_to_torch_parallel_style(style: str): +class ParallelInterface(MutableMapping): """ - In model configurations, we use a neutral type (string) to specify parallel - styles, here we translate them into torch.distributed tensor-parallel - types. + Dict-like object keeping track of allowed attention functions. You can easily add a new attention function + with a call to `register()`. If a model needs to locally overwrite an existing attention function, say `sdpa`, + it needs to declare a new instance of this class inside the `modeling_.py`, and declare it on that instance. 
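+
+    Example (sketch; `MyCustomParallel` stands in for any `TensorParallelLayer` subclass):
+        ALL_PARALLEL_STYLES.register("my_custom", MyCustomParallel())
+        tp_layer = ALL_PARALLEL_STYLES["my_custom"]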
""" - if not isinstance(style, str): - raise ValueError(f"Unsupported parallel style type {type(style)}, expected str") - if style == "colwise": - return ColwiseParallel() - elif style == "rowwise": - return RowwiseParallel() - elif style == "colwise_rep": - return ColwiseParallel(output_layouts=Replicate()) - elif style == "rowwise_rep": - return RowwiseParallel(input_layouts=Replicate()) - elif style == "local_colwise": - return ColwiseParallel(use_dtensor=False) - elif style == "local_rowwise": - return RowwiseParallel(use_dtensor=False) - elif style == "local": - return IsolatedParallel() - elif style == "gather": - return GatherParallel() - elif style == "local_packed_rowwise": - return PackedRowwiseParallel(use_dtensor=False) - elif style == "sequence_parallel": - return SequenceParallel() - else: - raise ValueError(f"Unsupported parallel style value: {style}") + # Class instance object, so that a call to `register` can be reflected into all other files correctly, even if + # a new instance is created (in order to locally override a given function) + _global_mapping = { + "colwise": ColwiseParallel(), + "rowwise": RowwiseParallel(), + "colwise_rep": ColwiseParallel(output_layouts=Replicate()), + "rowwise_rep": RowwiseParallel(input_layouts=Replicate()), + "local_colwise": ColwiseParallel(use_dtensor=False), + "local_rowwise": RowwiseParallel(use_dtensor=False), + "local": IsolatedParallel(), + "gather": GatherParallel(), + "local_packed_rowwise": PackedRowwiseParallel(use_dtensor=False), + "sequence_parallel": SequenceParallel(), + "replicate": ReplicateParallel(), + } + + def __init__(self): + self._local_mapping = {} + + def __getitem__(self, key): + # First check if instance has a local override + if key in self._local_mapping: + return self._local_mapping[key] + return self._global_mapping[key] + + def __setitem__(self, key, value): + # Allow local update of the default functions without impacting other instances + self._local_mapping.update({key: value}) + + def __delitem__(self, key): + del self._local_mapping[key] + + def __iter__(self): + # Ensure we use all keys, with the overwritten ones on top + return iter({**self._global_mapping, **self._local_mapping}) + + def __len__(self): + return len(self._global_mapping.keys() | self._local_mapping.keys()) + + @classmethod + def register(cls, key: str, value: Callable): + cls._global_mapping.update({key: value}) + + def valid_keys(self) -> List[str]: + return list(self.keys()) + + +# Global AttentionInterface shared by all models which do not need to overwrite any of the existing ones +ALL_PARALLEL_STYLES: ParallelInterface = ParallelInterface() def convert_local_tensor_to_dtensor( @@ -722,13 +841,15 @@ def add_tensor_parallel_hooks_to_module(model, module, tp_plan, layer_name, curr # 1. We add hooks to the layer being loaded: if current_module_plan is not None: - tp_layer = translate_to_torch_parallel_style(current_module_plan) + tp_layer = ALL_PARALLEL_STYLES[current_module_plan] try: tp_layer.prepare_module_tp(module, device_mesh) except NotImplementedError as e: print( f"Trying to prepare {layer_name}, but it's not supported. Corresponding module: {module} Fix it's TP plan: {e}" ) + module._hf_tp_plan = current_module_plan + module.__repr__ = lambda: f"{module.__repr__()}\nTP Plan: {current_module_plan}" # 2. We add hooks to the parent module if needed if "." 
@@ -722,13 +841,15 @@ def add_tensor_parallel_hooks_to_module(model, module, tp_plan, layer_name, curr
 
     # 1. We add hooks to the layer being loaded:
     if current_module_plan is not None:
-        tp_layer = translate_to_torch_parallel_style(current_module_plan)
+        tp_layer = ALL_PARALLEL_STYLES[current_module_plan]
         try:
             tp_layer.prepare_module_tp(module, device_mesh)
         except NotImplementedError as e:
             print(
                 f"Trying to prepare {layer_name}, but it's not supported. Corresponding module: {module} Fix it's TP plan: {e}"
             )
+        module._hf_tp_plan = current_module_plan
+        original_repr = module.__repr__
+        module.__repr__ = lambda: f"{original_repr()}\nTP Plan: {current_module_plan}"
 
     # 2. We add hooks to the parent module if needed
     if "." in layer_name:
@@ -736,9 +857,11 @@ def add_tensor_parallel_hooks_to_module(model, module, tp_plan, layer_name, curr
         generic_name = re.sub(r"\d+", "*", parent_layer_name)
         # The module itself needs hooks
         if module_plan := tp_plan.get(generic_name, False):
-            tp_layer = translate_to_torch_parallel_style(module_plan)
+            tp_layer = ALL_PARALLEL_STYLES[module_plan]
             module_to_tp_ = model.get_submodule(parent_layer_name)
             tp_layer.prepare_module_tp(module_to_tp_, device_mesh)
+            module_to_tp_._hf_tp_plan = current_module_plan
+            original_repr_ = module_to_tp_.__repr__
+            module_to_tp_.__repr__ = lambda: f"{original_repr_()}\nTP Plan: {current_module_plan}"
 
 
 def shard_and_distribute_module(
@@ -760,28 +883,29 @@ def shard_and_distribute_module(
 
     current_module_plan = _get_parameter_tp_plan(parameter_name, tp_plan)
 
+    if current_module_plan is None:
+        current_module_plan = "replicate"
+        if dist.get_rank() == 0:
+            logger.info(f"Tensor parallel plan for {param_name} not found, using default 'replicate' plan.")
+    else:
+        if dist.get_rank() == 0:
+            logger.info(f"Tensor parallel plan for {param_name}: {current_module_plan}")
+
     # Add hooks to the module if not done yet
     # add_tensor_parallel_hooks_to_module(model, module_to_tp, tp_plan, param_name, current_module_plan, device_mesh)
     if not getattr(module_to_tp, "_is_hooked", False):
        add_tensor_parallel_hooks_to_module(model, module_to_tp, tp_plan, param_name, current_module_plan, device_mesh)
        module_to_tp._is_hooked = True
 
-    if current_module_plan is not None:
-        try:
-            tp_layer = translate_to_torch_parallel_style(current_module_plan)
-            param = tp_layer.partition_tensor(
-                param, empty_param, param_type, param_casting_dtype, is_contiguous, rank, device_mesh
-            )
-        except NotImplementedError as e:
-            print(
-                f"Trying to prepare {parameter_name}, but it's not supported. Corresponding module: {module_to_tp} Fix it's TP plan, current layer: {tp_layer} : {e}"
-            )
-    else:
-        # TODO log no plan modules in set
-        # print("No plan for", parameter_name,end ="\n")
-        param = param[...].to(param_casting_dtype)
-        if is_contiguous:
-            param = param.contiguous()
+    try:
+        tp_layer = ALL_PARALLEL_STYLES[current_module_plan]
+        param = tp_layer.partition_tensor(
+            param, empty_param, param_type, param_casting_dtype, is_contiguous, rank, device_mesh
+        )
+    except NotImplementedError as e:
+        print(
+            f"Trying to prepare {parameter_name}, but it's not supported. Corresponding module: {module_to_tp} Fix its TP plan, current layer: {tp_layer} : {e}"
+        )
 
     # SUPER IMPORTANT we have to use setattr
     # otherwise loading is crazy slow
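The two lookups above, the `re.sub(r"\d+", "*", ...)` generic-name match and the new "replicate" fallback, can be illustrated with a self-contained toy. The plan dict below is made up for the sketch and does not come from any real model:

import re

tp_plan = {
    "model.layers.*.self_attn.q_proj": "colwise",
    "model.layers.*.mlp.down_proj": "rowwise",
    "model.embed_tokens": "replicate",
}


def lookup(layer_name: str) -> str:
    # Concrete layer indices are generalized to "*" before consulting the plan.
    generic_name = re.sub(r"\d+", "*", layer_name)  # "model.layers.3.mlp.down_proj" -> "model.layers.*.mlp.down_proj"
    # Parameters without an entry now fall back to the default "replicate" style.
    return tp_plan.get(generic_name, "replicate")


print(lookup("model.layers.3.mlp.down_proj"))  # rowwise
print(lookup("model.norm"))                    # replicate

Falling back to "replicate" means every parameter goes through a `TensorParallelLayer.partition_tensor` call, which is what lets the removed "no plan" branch disappear.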
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index b40b7cd2b38..ece787c7b67 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -62,8 +62,9 @@ from .integrations.flash_attention import flash_attention_forward
 from .integrations.flex_attention import flex_attention_forward
 from .integrations.sdpa_attention import sdpa_attention_forward
 from .integrations.tensor_parallel import (
-    SUPPORTED_TP_STYLES,
+    ALL_PARALLEL_STYLES,
     _get_parameter_tp_plan,
+    initialize_tensor_parallelism,
     repack_weights,
     replace_state_dict_local_with_dtensor,
     shard_and_distribute_module,
@@ -797,7 +798,7 @@ def _load_state_dict_into_meta_model(
                 param_name,
                 casting_dtype,
                 to_contiguous,
-                int(os.environ["RANK"]),  # the rank
+                device_mesh.get_local_rank(),
                 device_mesh,
             )
         else:
@@ -1964,9 +1965,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
 
         if self._tp_plan is not None and is_torch_greater_or_equal("2.3"):
             for _, v in self._tp_plan.items():
-                if v not in SUPPORTED_TP_STYLES:
+                if v not in ALL_PARALLEL_STYLES:
                     raise ValueError(
-                        f"Unsupported tensor parallel style {v}. Supported styles are {SUPPORTED_TP_STYLES}"
+                        f"Unsupported tensor parallel style {v}. Supported styles are {ALL_PARALLEL_STYLES}"
                     )
 
     def dequantize(self):
@@ -3559,6 +3560,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
            state_dict = replace_state_dict_local_with_dtensor(state_dict, self._tp_plan, self._device_mesh)
 
        if safe_serialization:
+            # TODO: fix safe_serialization for tied weights
            # Safetensors does not allow tensor aliasing.
            # We're going to remove aliases before saving
            ptrs = collections.defaultdict(list)
@@ -4040,6 +4042,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
                `torchrun [args] script.py`. This will be much faster than using a `device_map`, but has limitations.
            tp_size (`str`, *optional*):
                A torch tensor parallel degree. If not provided would default to world size.
+            device_mesh (`torch.distributed.DeviceMesh`, *optional*):
+                A torch device mesh. If not provided, a 1-D mesh over `tp_size` (default: world size) is created. Used only for tensor parallelism for now.
            offload_folder (`str` or `os.PathLike`, *optional*):
                If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
            offload_state_dict (`bool`, *optional*):
@@ -4137,6 +4141,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
        gguf_file = kwargs.pop("gguf_file", None)
        tp_plan = kwargs.pop("tp_plan", None)
        tp_size = kwargs.pop("tp_size", None)
+        device_mesh = kwargs.pop("device_mesh", None)
        trust_remote_code = kwargs.pop("trust_remote_code", None)
 
        # Load models with hardcoded key mapping on class for VLMs only, to keep BC and standardize model
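With the new `device_mesh` kwarg, a caller can build the mesh themselves instead of letting `from_pretrained` create one. A sketch of that usage under `torchrun`, assuming the patched signature above; the model id is a placeholder and the mesh must stay 1-dimensional for now:

import os

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

from transformers import AutoModelForCausalLM

dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

tp_mesh = init_device_mesh("cuda", (dist.get_world_size(),), mesh_dim_names=("tp",))
model = AutoModelForCausalLM.from_pretrained(
    "your-org/your-model",
    tp_plan="auto",       # use the model's predefined tensor parallel plan
    device_mesh=tp_mesh,  # reuse the mesh we built; must be 1-D for TP
)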
@@ -4172,59 +4177,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
 
        # We need to correctly dispatch the model on the current process device. The easiest way for this is to use a simple
        # `device_map` pointing to the correct device
-        device_mesh = None
-        if tp_plan is not None:
-            if not is_torch_greater_or_equal("2.5"):
-                raise EnvironmentError("tensor parallel is only supported for `torch>=2.5`.")
-
-            # Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
-            device_type = torch._C._get_accelerator().type
-
-            if not torch.distributed.is_initialized():
-                try:
-                    rank = int(os.environ["RANK"])
-                    world_size = int(os.environ["WORLD_SIZE"])
-                    if device_type == "cuda":
-                        torch.distributed.init_process_group(
-                            "nccl", rank=rank, world_size=world_size, init_method="env://"
-                        )
-                        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
-                    elif device_type == "cpu":
-                        cpu_backend = "ccl" if int(os.environ.get("CCL_WORKER_COUNT", 0)) else "gloo"
-                        torch.distributed.init_process_group(cpu_backend, rank=rank, world_size=world_size)
-                    elif device_type == "xpu":
-                        torch.distributed.init_process_group("ccl", rank=rank, world_size=world_size)
-                        torch.xpu.set_device(int(os.environ["LOCAL_RANK"]))
-                    elif device_type == "hpu":
-                        torch.distributed.init_process_group("hccl", rank=rank, world_size=world_size)
-                        torch.hpu.set_device(int(os.environ["LOCAL_RANK"]))
-
-                except Exception as e:
-                    raise EnvironmentError(
-                        "We tried to initialize torch.distributed for you, but it failed, make"
-                        "sure you init torch distributed in your script to use `tp_plan='auto'`"
-                    ) from e
-
-            # Get device with index assuming equal number of devices per host
-            if device_type == "xpu":
-                index = torch.xpu.current_device()
-            elif device_type == "hpu":
-                index = torch.hpu.current_device()
-            else:
-                index = None if device_type == "cpu" else torch.cuda.current_device()
-            tp_device = torch.device(device_type, index)
-
-            if index is not None and index > 0:
-                import sys
-
-                sys.stdout = open(os.devnull, "w")
-                sys.stderr = open(os.devnull, "w")
-            # This is the easiest way to dispatch to the current process device
-            device_map = tp_device
-
-            # Assuming sharding the model onto the world when tp_size not provided
-            tp_size = tp_size if tp_size is not None else torch.distributed.get_world_size()
-            device_mesh = torch.distributed.init_device_mesh(tp_device.type, (tp_size,))
+        if device_mesh is None:
+            tp_plan, device_map, device_mesh = initialize_tensor_parallelism(tp_plan, tp_size=tp_size)
+        else:
+            # TODO: make device_mesh support multiple dimensions
+            if device_mesh.ndim != 1:
+                raise ValueError("device_mesh must be 1 dimensional and will be used for TP")
+            device_map = torch.device(device_mesh.device_type, int(os.environ["LOCAL_RANK"]))
 
        if use_auth_token is not None:
            warnings.warn(
@@ -5142,7 +5101,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
                        name,
                        casting_dtype,
                        to_contiguous,
-                        os.environ["RANK"],
+                        device_mesh.get_local_rank(),
                        device_mesh,
                    )
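The default path, where no mesh is passed, stays the zero-setup one: `tp_plan="auto"` initializes torch.distributed, builds a 1-D device mesh over `tp_size` (or the world size), and shards the weights while loading. A minimal sketch of that usage, launched with something like `torchrun --nproc_per_node=4 load_tp.py` (script name and model id are placeholders):

import torch

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("your-org/your-model", tp_plan="auto", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained("your-org/your-model")

inputs = tokenizer("Hello world", return_tensors="pt").to(model.device)
with torch.no_grad():
    outputs = model(**inputs)
print(outputs.logits.shape)  # full vocabulary logits on every rank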