diff --git a/examples/lm_finetuning/finetune_on_pregenerated.py b/examples/lm_finetuning/finetune_on_pregenerated.py index fd38025424b..6920e597e9d 100644 --- a/examples/lm_finetuning/finetune_on_pregenerated.py +++ b/examples/lm_finetuning/finetune_on_pregenerated.py @@ -155,11 +155,10 @@ def main(): help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n" "0 (default value): dynamic loss scaling.\n" "Positive power of 2: static loss scaling value.\n") - parser.add_argument("--warmup_proportion", - default=0.1, - type=float, - help="Proportion of training to perform linear learning rate warmup for. " - "E.g., 0.1 = 10%% of training.") + parser.add_argument("--warmup_steps", + default=0, + type=int, + help="Linear warmup over warmup_steps.") parser.add_argument("--learning_rate", default=3e-5, type=float, @@ -270,13 +269,9 @@ def main(): optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True) else: optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale) - warmup_linear = WarmupLinearSchedule(warmup=args.warmup_proportion, - t_total=num_train_optimization_steps) else: - optimizer = AdamW(optimizer_grouped_parameters, - lr=args.learning_rate, - warmup=args.warmup_proportion, - t_total=num_train_optimization_steps) + optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) + scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=num_train_optimization_steps) global_step = 0 logging.info("***** Running training *****") @@ -298,7 +293,8 @@ def main(): for step, batch in enumerate(train_dataloader): batch = tuple(t.to(device) for t in batch) input_ids, input_mask, segment_ids, lm_label_ids, is_next = batch - loss = model(input_ids, segment_ids, input_mask, lm_label_ids, is_next) + outputs = model(input_ids, segment_ids, input_mask, lm_label_ids, is_next) + loss = outputs[0] if n_gpu > 1: loss = loss.mean() # mean() to average on multi-gpu. if args.gradient_accumulation_steps > 1: @@ -314,18 +310,13 @@ def main(): mean_loss = tr_loss * args.gradient_accumulation_steps / nb_tr_steps pbar.set_postfix_str(f"Loss: {mean_loss:.5f}") if (step + 1) % args.gradient_accumulation_steps == 0: - if args.fp16: - # modify learning rate with special warm up BERT uses - # if args.fp16 is False, BertAdam is used that handles this automatically - lr_this_step = args.learning_rate * warmup_linear.get_lr(global_step, args.warmup_proportion) - for param_group in optimizer.param_groups: - param_group['lr'] = lr_this_step + scheduler.step() # Update learning rate schedule optimizer.step() optimizer.zero_grad() global_step += 1 # Save a trained model - if torch.distributed.get_rank() == 0: + if n_gpu > 1 and torch.distributed.get_rank() == 0 or n_gpu <=1 : logging.info("** ** * Saving fine-tuned model ** ** * ") model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self diff --git a/examples/lm_finetuning/simple_lm_finetuning.py b/examples/lm_finetuning/simple_lm_finetuning.py index 0783c1bcc34..6eff9e4140a 100644 --- a/examples/lm_finetuning/simple_lm_finetuning.py +++ b/examples/lm_finetuning/simple_lm_finetuning.py @@ -438,11 +438,10 @@ def main(): default=3.0, type=float, help="Total number of training epochs to perform.") - parser.add_argument("--warmup_proportion", - default=0.1, - type=float, - help="Proportion of training to perform linear learning rate warmup for. " - "E.g., 0.1 = 10%% of training.") + parser.add_argument("--warmup_steps", + default=0, + type=int, + help="Linear warmup over warmup_steps.") parser.add_argument("--no_cuda", action='store_true', help="Whether not to use CUDA when available") @@ -504,7 +503,7 @@ def main(): if os.path.exists(args.output_dir) and os.listdir(args.output_dir): raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir)) - if not os.path.exists(args.output_dir) and torch.distributed.get_rank() == 0: + if not os.path.exists(args.output_dir) and ( n_gpu > 1 and torch.distributed.get_rank() == 0 or n_gpu <=1 ): os.makedirs(args.output_dir) tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) @@ -558,14 +557,10 @@ def main(): optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True) else: optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale) - warmup_linear = WarmupLinearSchedule(warmup=args.warmup_proportion, - t_total=num_train_optimization_steps) else: - optimizer = BertAdam(optimizer_grouped_parameters, - lr=args.learning_rate, - warmup=args.warmup_proportion, - t_total=num_train_optimization_steps) + optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) + scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=num_train_optimization_steps) global_step = 0 if args.do_train: @@ -589,7 +584,8 @@ def main(): for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")): batch = tuple(t.to(device) for t in batch) input_ids, input_mask, segment_ids, lm_label_ids, is_next = batch - loss = model(input_ids, segment_ids, input_mask, lm_label_ids, is_next) + outputs = model(input_ids, segment_ids, input_mask, lm_label_ids, is_next) + loss = outputs[0] if n_gpu > 1: loss = loss.mean() # mean() to average on multi-gpu. if args.gradient_accumulation_steps > 1: @@ -602,22 +598,17 @@ def main(): nb_tr_examples += input_ids.size(0) nb_tr_steps += 1 if (step + 1) % args.gradient_accumulation_steps == 0: - if args.fp16: - # modify learning rate with special warm up BERT uses - # if args.fp16 is False, BertAdam is used that handles this automatically - lr_this_step = args.learning_rate * warmup_linear.get_lr(global_step, args.warmup_proportion) - for param_group in optimizer.param_groups: - param_group['lr'] = lr_this_step + scheduler.step() # Update learning rate schedule optimizer.step() optimizer.zero_grad() global_step += 1 # Save a trained model - model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self - output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME) - output_config_file = os.path.join(args.output_dir, CONFIG_NAME) - if args.do_train and torch.distributed.get_rank() == 0: + if args.do_train and ( n_gpu > 1 and torch.distributed.get_rank() == 0 or n_gpu <=1): logger.info("** ** * Saving fine - tuned model ** ** * ") + model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self + output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME) + output_config_file = os.path.join(args.output_dir, CONFIG_NAME) torch.save(model_to_save.state_dict(), output_model_file) model_to_save.config.to_json_file(output_config_file) tokenizer.save_vocabulary(args.output_dir)