import logging
import os
import time
from argparse import Namespace
from pathlib import Path

import datasets
import torch
from datasets import load_dataset
from torch.utils.data import IterableDataset
from torch.utils.data.dataloader import DataLoader

import transformers
from accelerate import Accelerator, DistributedType
from arguments import TrainingArguments
from huggingface_hub import Repository
from transformers import AdamW, AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, get_scheduler, set_seed


class ConstantLengthDataset(IterableDataset):
    """
    Iterable dataset that returns constant length chunks of tokens from a stream of text files.
        Args:
            tokenizer (Tokenizer): The processor used for processing the data.
            dataset (dataset.Dataset): Dataset with text files.
            infinite (bool): If True, the iterator restarts once the dataset is exhausted, else it stops.
            seq_length (int): Length of token sequences to return.
            num_of_sequences (int): Number of token sequences to keep in buffer.
            chars_per_token (float): Number of characters per token used to estimate number of tokens in text buffer.
            tokenized (bool): If True, we use a pretokenized dataset.
    """

    def __init__(
        self,
        tokenizer,
        dataset,
        infinite=False,
        seq_length=1024,
        num_of_sequences=1024,
        chars_per_token=3.6,
        tokenized=False,
    ):
        self.tokenizer = tokenizer
        self.concat_token_id = tokenizer.bos_token_id
        self.dataset = dataset
        self.seq_length = seq_length
        self.epoch = 0
        self.infinite = infinite
        self.current_size = 0
        self.tokenized = tokenized

        if self.tokenized:
            self.max_buffer_size = seq_length * num_of_sequences
            self.content_field = "input_ids"
        else:
            self.max_buffer_size = seq_length * chars_per_token * num_of_sequences
            self.content_field = "content"

    def __iter__(self):
        iterator = iter(self.dataset)
        more_examples = True
        while more_examples:
            # Fill the buffer with raw examples until the estimated token budget is reached
            buffer, buffer_len = [], 0
            while True:
                if buffer_len >= self.max_buffer_size:
                    break
                try:
                    buffer.append(next(iterator)[self.content_field])
                    buffer_len += len(buffer[-1])
                except StopIteration:
                    if self.infinite:
                        iterator = iter(self.dataset)
                        self.epoch += 1
                        logger.info(f"Dataset epoch: {self.epoch}")
                    else:
                        more_examples = False
                        break
            if self.tokenized:
                tokenized_inputs = buffer
            else:
                tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
            # Concatenate all documents, separated by the BOS token, and slice into fixed-length chunks;
            # the final incomplete chunk is dropped
            all_token_ids = []
            for tokenized_input in tokenized_inputs:
                all_token_ids.extend(tokenized_input + [self.concat_token_id])
            for i in range(0, len(all_token_ids), self.seq_length):
                input_ids = all_token_ids[i : i + self.seq_length]
                if len(input_ids) == self.seq_length:
                    self.current_size += 1
                    yield torch.tensor(input_ids)
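
# A minimal usage sketch of ConstantLengthDataset (not executed as part of training; the dataset name
# and sizes below are placeholders, not values from this script):
#
#   streamed_data = load_dataset("my-org/my-code-dataset", split="train", streaming=True)
#   const_len_ds = ConstantLengthDataset(tokenizer, streamed_data, infinite=False, seq_length=1024)
#   chunk = next(iter(const_len_ds))  # 1D LongTensor of length 1024, documents joined by the BOS token
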
def setup_logging(args):
    project_name = args.model_ckpt.split("/")[-1]
    logger = logging.getLogger(__name__)
    log_dir = Path(args.save_dir) / "log/"
    log_dir.mkdir(exist_ok=True)
    filename = f"debug_{accelerator.process_index}.log"
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
        handlers=[logging.FileHandler(log_dir / filename), logging.StreamHandler()],
    )
    if accelerator.is_main_process:  # we only want to set up logging once
        accelerator.init_trackers(project_name, vars(args))
        run_name = accelerator.trackers[0].run.name
        logger.setLevel(logging.INFO)
        datasets.utils.logging.set_verbosity_info()
        transformers.utils.logging.set_verbosity_info()
    else:
        run_name = ""
        logger.setLevel(logging.ERROR)
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()
    return logger, run_name


def create_dataloaders(args):
    ds_kwargs = {"streaming": True}
    train_data = load_dataset(args.dataset_name_train, split="train", **ds_kwargs)
    train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=args.seed)
    valid_data = load_dataset(args.dataset_name_valid, split="train", **ds_kwargs)
    train_dataset = ConstantLengthDataset(
        tokenizer, train_data, infinite=True, seq_length=args.seq_length, tokenized=args.tokenized
    )
    valid_dataset = ConstantLengthDataset(
        tokenizer, valid_data, infinite=False, seq_length=args.seq_length, tokenized=args.tokenized
    )
    train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size)
    eval_dataloader = DataLoader(valid_dataset, batch_size=args.valid_batch_size)
    return train_dataloader, eval_dataloader


def get_grouped_params(model, args, no_decay=["bias", "LayerNorm.weight"]):
    params_with_wd, params_without_wd = [], []
    for n, p in model.named_parameters():
        if any(nd in n for nd in no_decay):
            params_without_wd.append(p)
        else:
            params_with_wd.append(p)
    return [
        {"params": params_with_wd, "weight_decay": args.weight_decay},
        {"params": params_without_wd, "weight_decay": 0.0},
    ]


def log_metrics(step, metrics):
    logger.info(f"Step {step}: {metrics}")
    if accelerator.is_main_process:
        accelerator.log(metrics, step)


def compute_tflops(elapsed_time, accelerator, args):
    # TFLOPs formula (from Equation 3 in Section 5.1 of https://arxiv.org/pdf/2104.04473.pdf).
    config_model = accelerator.unwrap_model(model).config
    checkpoint_factor = 4 if args.gradient_checkpointing else 3
    batch_size = args.train_batch_size * accelerator.state.num_processes * args.gradient_accumulation_steps
    factor = 24 * checkpoint_factor * batch_size * args.seq_length * config_model.n_layer * (config_model.n_embd**2)
    flops_per_iteration = factor * (
        1.0
        + (args.seq_length / (6.0 * config_model.n_embd))
        + (tokenizer.vocab_size / (16.0 * config_model.n_layer * config_model.n_embd))
    )
    tflops = flops_per_iteration / (elapsed_time * accelerator.state.num_processes * (10**12))
    return tflops
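
# For reference, a readable form of the estimate implemented above (Megatron-LM, Eq. 3, Sec. 5.1 of
# https://arxiv.org/pdf/2104.04473.pdf), with B = global batch size, s = seq_length, l = n_layer,
# h = n_embd and V = vocab_size:
#
#   FLOPs per iteration = 24 * c * B * s * l * h^2 * (1 + s / (6h) + V / (16 * l * h))
#
# where c = 4 with gradient checkpointing (the extra forward pass) and c = 3 without. Dividing by the
# elapsed time per iteration and the number of processes yields TFLOPs per device.
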
def evaluate(args):
    model.eval()
    losses = []
    for step, batch in enumerate(eval_dataloader):
        with torch.no_grad():
            outputs = model(batch, labels=batch)
        loss = outputs.loss.repeat(args.valid_batch_size)
        losses.append(accelerator.gather(loss))
        if args.max_eval_steps > 0 and step >= args.max_eval_steps:
            break
    losses = torch.cat(losses)
    loss = losses[: eval_dataloader.dataset.current_size].mean()
    try:
        perplexity = torch.exp(loss)
    except OverflowError:
        perplexity = torch.tensor(float("inf"))
    return loss.item(), perplexity.item()


# Accelerator
accelerator = Accelerator(log_with=["wandb", "tensorboard"])
acc_state = {str(k): str(v) for k, v in accelerator.state.__dict__.items()}

# Settings
parser = HfArgumentParser(TrainingArguments)
args = parser.parse_args()
args = Namespace(**vars(args), **acc_state)
samples_per_step = accelerator.state.num_processes * args.train_batch_size
set_seed(args.seed)

# Clone model repository
if accelerator.is_main_process:
    hf_repo = Repository(args.save_dir, clone_from=args.model_ckpt)

# Logging
logger, run_name = setup_logging(args)
logger.info(accelerator.state)

# Checkout new branch on repo
if accelerator.is_main_process:
    hf_repo.git_checkout(run_name, create_branch_ok=True)

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(args.save_dir)
if args.gradient_checkpointing:
    model.gradient_checkpointing_enable()
tokenizer = AutoTokenizer.from_pretrained(args.save_dir)

# Load dataset and dataloader
train_dataloader, eval_dataloader = create_dataloaders(args)

# Prepare the optimizer and learning rate scheduler
optimizer = AdamW(get_grouped_params(model, args), lr=args.learning_rate)
lr_scheduler = get_scheduler(
    name=args.lr_scheduler_type,
    optimizer=optimizer,
    num_warmup_steps=args.num_warmup_steps,
    num_training_steps=args.max_train_steps,
)
accelerator.register_for_checkpointing(lr_scheduler)


def get_lr():
    return optimizer.param_groups[0]["lr"]


# Prepare everything with our `accelerator`.
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader
)

# Load in the weights and states from a previous save
if args.resume_from_checkpoint:
    if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
        accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
        accelerator.load_state(args.resume_from_checkpoint)
        path = os.path.basename(args.resume_from_checkpoint)
    else:
        # Get the most recent checkpoint
        dirs = [f.name for f in os.scandir(args.save_dir) if f.is_dir() and "step" in str(f)]
        dirs.sort(key=os.path.getctime)
        path = dirs[-1]  # Sorts folders by date modified, most recent checkpoint is the last
    # Extract the step of the checkpoint to continue from there
    training_difference = os.path.splitext(path)[0]
    resume_step = int(training_difference.replace("step_", ""))

# Train model
model.train()
completed_steps = 0
t_start = time.time()
for step, batch in enumerate(train_dataloader, start=1):
    if args.resume_from_checkpoint and step < resume_step:
        continue  # we need to skip steps until we reach the resumed step
    loss = model(batch, labels=batch, use_cache=False).loss
    log_metrics(
        step, {"lr": get_lr(), "samples": step * samples_per_step, "steps": completed_steps, "loss/train": loss.item()}
    )
    loss = loss / args.gradient_accumulation_steps
    if step % args.gradient_accumulation_steps != 0:
        # Prevent backward from doing gradient all_reduce in every step
        if accelerator.distributed_type == DistributedType.MULTI_GPU:
            with model.no_sync():
                accelerator.backward(loss)
        else:
            accelerator.backward(loss)
    else:
        accelerator.backward(loss)
        accelerator.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        completed_steps += 1
        elapsed_time = time.time() - t_start
        tflops = compute_tflops(elapsed_time, accelerator, args)
        log_metrics(step, {"steps": completed_steps, "tflops": tflops, "time_per_iteration": elapsed_time})
        t_start = time.time()
    if step % args.save_checkpoint_steps == 0:
        logger.info("Evaluating and saving model checkpoint")
        eval_loss, perplexity = evaluate(args)
        log_metrics(step, {"loss/eval": eval_loss, "perplexity": perplexity})
        accelerator.wait_for_everyone()
        save_dir = os.path.join(args.save_dir, f"step_{step}")
        accelerator.save_state(save_dir)
        if accelerator.is_main_process:
            hf_repo.push_to_hub(commit_message=f"step {step}")
        model.train()
    if completed_steps >= args.max_train_steps:
        break

# Evaluate and save the last checkpoint
logger.info("Evaluating and saving model after training")
eval_loss, perplexity = evaluate(args)
log_metrics(step, {"loss/eval": eval_loss, "perplexity": perplexity})
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(args.save_dir, save_function=accelerator.save)
save_dir = os.path.join(args.save_dir, f"step_{step}")
accelerator.save_state(save_dir)
if accelerator.is_main_process:
    hf_repo.push_to_hub(commit_message="final model")
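
# Launch sketch (the script file name and the values below are hypothetical; the flag names correspond
# to fields of TrainingArguments in arguments.py, exposed via HfArgumentParser above). With Accelerate
# configured for your hardware, a run could look like:
#
#   accelerate launch codeparrot_training.py \
#       --model_ckpt <org>/<model-repo> \
#       --save_dir ./<model-repo> \
#       --train_batch_size 12 \
#       --gradient_accumulation_steps 1 \
#       --max_train_steps 150000 \
#       --save_checkpoint_steps 15000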