diff --git a/examples/research_projects/codeparrot/scripts/codeparrot_training.py b/examples/research_projects/codeparrot/scripts/codeparrot_training.py
index 771109b1d38..3d9b20bdc08 100644
--- a/examples/research_projects/codeparrot/scripts/codeparrot_training.py
+++ b/examples/research_projects/codeparrot/scripts/codeparrot_training.py
@@ -59,7 +59,7 @@ class ConstantLengthDataset(IterableDataset):
                     else:
                         more_examples = False
                         break
-            tokenized_inputs = tokenizer(buffer, truncation=False)["input_ids"]
+            tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
             all_token_ids = []
             for tokenized_input in tokenized_inputs:
                 all_token_ids.extend(tokenized_input + [self.concat_token_id])
@@ -68,173 +68,176 @@ class ConstantLengthDataset(IterableDataset):
                 if len(input_ids) == self.seq_length:
                     yield torch.tensor(input_ids)
 
-
-def setup_logging(args):
-    project_name = args.model_ckpt.split("/")[-1]
-    logger = logging.getLogger(__name__)
-    log_dir = Path(args.save_dir) / "log/"
-    log_dir.mkdir(exist_ok=True)
-    filename = f"debug_{accelerator.process_index}.log"
-    logging.basicConfig(
-        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
-        datefmt="%m/%d/%Y %H:%M:%S",
-        level=logging.INFO,
-        handlers=[logging.FileHandler(log_dir / filename), logging.StreamHandler()],
-    )
-    if accelerator.is_main_process:  # we only want to setup logging once
-        wandb.init(project=project_name, config=args)
-        run_name = wandb.run.name
-        tb_writer = SummaryWriter()
-        tb_writer.add_hparams(vars(args), {"0": 0})
-        logger.setLevel(logging.INFO)
-        datasets.utils.logging.set_verbosity_info()
-        transformers.utils.logging.set_verbosity_info()
-    else:
-        tb_writer = None
-        run_name = ""
-        logger.setLevel(logging.ERROR)
-        datasets.utils.logging.set_verbosity_error()
-        transformers.utils.logging.set_verbosity_error()
-    return logger, tb_writer, run_name
-
-
-def create_dataloaders(args):
-    ds_kwargs = {"streaming": True}
-    train_data = load_dataset(args.dataset_name_train, split="train", **ds_kwargs)
-    train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=args.seed)
-    valid_data = load_dataset(args.dataset_name_valid, split="train", **ds_kwargs)
-    train_dataset = ConstantLengthDataset(tokenizer, train_data, infinite=True, seq_length=args.seq_length)
-    valid_dataset = ConstantLengthDataset(tokenizer, valid_data, infinite=False, seq_length=args.seq_length)
-    train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size)
-    eval_dataloader = DataLoader(valid_dataset, batch_size=args.valid_batch_size)
-    return train_dataloader, eval_dataloader
-
-
-def get_grouped_params(model, args, no_decay=["bias", "LayerNorm.weight"]):
-    params_with_wd, params_without_wd = [], []
-    for n, p in model.named_parameters():
-        if any(nd in n for nd in no_decay):
-            params_without_wd.append(p)
+def main():
+    def setup_logging(args):
+        project_name = args.model_ckpt.split("/")[-1]
+        logger = logging.getLogger(__name__)
+        log_dir = Path(args.save_dir) / "log/"
+        log_dir.mkdir(exist_ok=True)
+        filename = f"debug_{accelerator.process_index}.log"
+        logging.basicConfig(
+            format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+            datefmt="%m/%d/%Y %H:%M:%S",
+            level=logging.INFO,
+            handlers=[logging.FileHandler(log_dir / filename), logging.StreamHandler()],
+        )
+        if accelerator.is_main_process:  # we only want to setup logging once
+            wandb.init(project=project_name, config=args)
+            run_name = wandb.run.name
+            tb_writer = SummaryWriter()
+            tb_writer.add_hparams(vars(args), {"0": 0})
+            logger.setLevel(logging.INFO)
+            datasets.utils.logging.set_verbosity_info()
+            transformers.utils.logging.set_verbosity_info()
         else:
-            params_with_wd.append(p)
-    return [
-        {"params": params_with_wd, "weight_decay": args.weight_decay},
-        {"params": params_without_wd, "weight_decay": 0.0},
-    ]
+            tb_writer = None
+            run_name = ""
+            logger.setLevel(logging.ERROR)
+            datasets.utils.logging.set_verbosity_error()
+            transformers.utils.logging.set_verbosity_error()
+        return logger, tb_writer, run_name
 
 
-def log_metrics(step, metrics):
-    logger.info(f"Step {step}: {metrics}")
-    if accelerator.is_main_process:
-        wandb.log(metrics)
-        [tb_writer.add_scalar(k, v, step) for k, v in metrics.items()]
+    def create_dataloaders(args):
+        ds_kwargs = {"streaming": True}
+        train_data = load_dataset(args.dataset_name_train, split="train", **ds_kwargs)
+        train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=args.seed)
+        valid_data = load_dataset(args.dataset_name_valid, split="train", **ds_kwargs)
+        train_dataset = ConstantLengthDataset(tokenizer, train_data, infinite=True, seq_length=args.seq_length)
+        valid_dataset = ConstantLengthDataset(tokenizer, valid_data, infinite=False, seq_length=args.seq_length)
+        train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size)
+        eval_dataloader = DataLoader(valid_dataset, batch_size=args.valid_batch_size)
+        return train_dataloader, eval_dataloader
 
 
-def evaluate(args):
-    model.eval()
-    losses = []
-    for step, batch in enumerate(eval_dataloader):
-        with torch.no_grad():
-            outputs = model(batch, labels=batch)
-        loss = outputs.loss.repeat(args.valid_batch_size)
-        losses.append(accelerator.gather(loss))
-        if args.max_eval_steps > 0 and step >= args.max_eval_steps:
-            break
-    loss = torch.mean(torch.cat(losses))
-    try:
-        perplexity = torch.exp(loss)
-    except OverflowError:
-        perplexity = float("inf")
-    return loss.item(), perplexity.item()
+    def get_grouped_params(model, args, no_decay=["bias", "LayerNorm.weight"]):
+        params_with_wd, params_without_wd = [], []
+        for n, p in model.named_parameters():
+            if any(nd in n for nd in no_decay):
+                params_without_wd.append(p)
+            else:
+                params_with_wd.append(p)
+        return [
+            {"params": params_with_wd, "weight_decay": args.weight_decay},
+            {"params": params_without_wd, "weight_decay": 0.0},
+        ]
 
 
-# Accelerator
-accelerator = Accelerator()
-acc_state = {str(k): str(v) for k, v in accelerator.state.__dict__.items()}
-
-# Settings
-parser = HfArgumentParser(TrainingArguments)
-args = parser.parse_args()
-
-args = Namespace(**vars(args), **acc_state)
-samples_per_step = accelerator.state.num_processes * args.train_batch_size
-set_seed(args.seed)
-
-# Clone model repository
-if accelerator.is_main_process:
-    hf_repo = Repository(args.save_dir, clone_from=args.model_ckpt)
-
-# Logging
-logger, tb_writer, run_name = setup_logging(args)
-logger.info(accelerator.state)
-
-# Checkout new branch on repo
-if accelerator.is_main_process:
-    hf_repo.git_checkout(run_name, create_branch_ok=True)
-
-# Load model and tokenizer
-model = AutoModelForCausalLM.from_pretrained(args.save_dir)
-if args.gradient_checkpointing:
-    model.gradient_checkpointing_enable()
-tokenizer = AutoTokenizer.from_pretrained(args.save_dir)
-
-# Load dataset and dataloader
-train_dataloader, eval_dataloader = create_dataloaders(args)
-
-# Prepare the optimizer and learning rate scheduler
-optimizer = AdamW(get_grouped_params(model, args), lr=args.learning_rate)
-lr_scheduler = get_scheduler(
-    name=args.lr_scheduler_type,
-    optimizer=optimizer,
-    num_warmup_steps=args.num_warmup_steps,
-    num_training_steps=args.max_train_steps,
-)
-
-
-def get_lr():
-    return optimizer.param_groups[0]["lr"]
-
-
-# Prepare everything with our `accelerator`.
-model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
-    model, optimizer, train_dataloader, eval_dataloader
-)
-
-# Train model
-model.train()
-completed_steps = 0
-for step, batch in enumerate(train_dataloader, start=1):
-    loss = model(batch, labels=batch, use_cache=False).loss
-    log_metrics(
-        step, {"lr": get_lr(), "samples": step * samples_per_step, "steps": completed_steps, "loss/train": loss.item()}
-    )
-    loss = loss / args.gradient_accumulation_steps
-    accelerator.backward(loss)
-    if step % args.gradient_accumulation_steps == 0:
-        accelerator.clip_grad_norm_(model.parameters(), 1.0)
-        optimizer.step()
-        lr_scheduler.step()
-        optimizer.zero_grad()
-        completed_steps += 1
-    if step % args.save_checkpoint_steps == 0:
-        logger.info("Evaluating and saving model checkpoint")
-        eval_loss, perplexity = evaluate(args)
-        log_metrics(step, {"loss/eval": eval_loss, "perplexity": perplexity})
-        accelerator.wait_for_everyone()
-        unwrapped_model = accelerator.unwrap_model(model)
-        unwrapped_model.save_pretrained(args.save_dir, save_function=accelerator.save)
+    def log_metrics(step, metrics):
+        logger.info(f"Step {step}: {metrics}")
         if accelerator.is_main_process:
-            hf_repo.push_to_hub(commit_message=f"step {step}")
-        model.train()
-    if completed_steps >= args.max_train_steps:
-        break
+            wandb.log(metrics)
+            [tb_writer.add_scalar(k, v, step) for k, v in metrics.items()]
 
-# Evaluate and save the last checkpoint
-logger.info("Evaluating and saving model after training")
-eval_loss, perplexity = evaluate(args)
-log_metrics(step, {"loss/eval": eval_loss, "perplexity": perplexity})
-accelerator.wait_for_everyone()
-unwrapped_model = accelerator.unwrap_model(model)
-unwrapped_model.save_pretrained(args.save_dir, save_function=accelerator.save)
-if accelerator.is_main_process:
-    hf_repo.push_to_hub(commit_message="final model")
+
+    def evaluate(args):
+        model.eval()
+        losses = []
+        for step, batch in enumerate(eval_dataloader):
+            with torch.no_grad():
+                outputs = model(batch, labels=batch)
+            loss = outputs.loss.repeat(args.valid_batch_size)
+            losses.append(accelerator.gather(loss))
+            if args.max_eval_steps > 0 and step >= args.max_eval_steps:
+                break
+        loss = torch.mean(torch.cat(losses))
+        try:
+            perplexity = torch.exp(loss)
+        except OverflowError:
+            perplexity = float("inf")
+        return loss.item(), perplexity.item()
+
+
+    # Accelerator
+    accelerator = Accelerator()
+    acc_state = {str(k): str(v) for k, v in accelerator.state.__dict__.items()}
+
+    # Settings
+    parser = HfArgumentParser(TrainingArguments)
+    args = parser.parse_args()
+
+    args = Namespace(**vars(args), **acc_state)
+    samples_per_step = accelerator.state.num_processes * args.train_batch_size
+    set_seed(args.seed)
+
+    # Clone model repository
+    if accelerator.is_main_process:
+        hf_repo = Repository(args.save_dir, clone_from=args.model_ckpt)
+
+    # Logging
+    logger, tb_writer, run_name = setup_logging(args)
+    logger.info(accelerator.state)
+
+    # Checkout new branch on repo
+    if accelerator.is_main_process:
+        hf_repo.git_checkout(run_name, create_branch_ok=True)
+
+    # Load model and tokenizer
+    model = AutoModelForCausalLM.from_pretrained(args.save_dir)
+    if args.gradient_checkpointing:
+        model.gradient_checkpointing_enable()
+    tokenizer = AutoTokenizer.from_pretrained(args.save_dir)
+
+    # Load dataset and dataloader
+    train_dataloader, eval_dataloader = create_dataloaders(args)
+
+    # Prepare the optimizer and learning rate scheduler
+    optimizer = AdamW(get_grouped_params(model, args), lr=args.learning_rate)
+    lr_scheduler = get_scheduler(
+        name=args.lr_scheduler_type,
+        optimizer=optimizer,
+        num_warmup_steps=args.num_warmup_steps,
+        num_training_steps=args.max_train_steps,
+    )
+
+
+    def get_lr():
+        return optimizer.param_groups[0]["lr"]
+
+
+    # Prepare everything with our `accelerator`.
+    model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
+        model, optimizer, train_dataloader, eval_dataloader
+    )
+
+    # Train model
+    model.train()
+    completed_steps = 0
+    for step, batch in enumerate(train_dataloader, start=1):
+        loss = model(batch, labels=batch, use_cache=False).loss
+        log_metrics(
+            step, {"lr": get_lr(), "samples": step * samples_per_step, "steps": completed_steps, "loss/train": loss.item()}
+        )
+        loss = loss / args.gradient_accumulation_steps
+        accelerator.backward(loss)
+        if step % args.gradient_accumulation_steps == 0:
+            accelerator.clip_grad_norm_(model.parameters(), 1.0)
+            optimizer.step()
+            lr_scheduler.step()
+            optimizer.zero_grad()
+            completed_steps += 1
+        if step % args.save_checkpoint_steps == 0:
+            logger.info("Evaluating and saving model checkpoint")
+            eval_loss, perplexity = evaluate(args)
+            log_metrics(step, {"loss/eval": eval_loss, "perplexity": perplexity})
+            accelerator.wait_for_everyone()
+            unwrapped_model = accelerator.unwrap_model(model)
+            unwrapped_model.save_pretrained(args.save_dir, save_function=accelerator.save)
+            if accelerator.is_main_process:
+                hf_repo.push_to_hub(commit_message=f"step {step}")
+            model.train()
+        if completed_steps >= args.max_train_steps:
+            break
+
+    # Evaluate and save the last checkpoint
+    logger.info("Evaluating and saving model after training")
+    eval_loss, perplexity = evaluate(args)
+    log_metrics(step, {"loss/eval": eval_loss, "perplexity": perplexity})
+    accelerator.wait_for_everyone()
+    unwrapped_model = accelerator.unwrap_model(model)
+    unwrapped_model.save_pretrained(args.save_dir, save_function=accelerator.save)
+    if accelerator.is_main_process:
+        hf_repo.push_to_hub(commit_message="final model")
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file