diff --git a/examples/lm_finetuning/finetune_on_pregenerated.py b/examples/lm_finetuning/finetune_on_pregenerated.py
index 9fcc5f2cb1a..ccf1c153130 100644
--- a/examples/lm_finetuning/finetune_on_pregenerated.py
+++ b/examples/lm_finetuning/finetune_on_pregenerated.py
@@ -314,8 +314,8 @@ def main():
                 mean_loss = tr_loss * args.gradient_accumulation_steps / nb_tr_steps
                 pbar.set_postfix_str(f"Loss: {mean_loss:.5f}")
                 if (step + 1) % args.gradient_accumulation_steps == 0:
-                    scheduler.step()  # Update learning rate schedule
                     optimizer.step()
+                    scheduler.step()  # Update learning rate schedule
                     optimizer.zero_grad()
                     global_step += 1
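
The change moves scheduler.step() so that it runs after optimizer.step(): since PyTorch 1.1.0, stepping the LR scheduler before the optimizer skips the first value of the schedule and triggers a UserWarning. Below is a minimal, self-contained sketch of the corrected ordering inside a gradient-accumulation loop; the model, optimizer, and scheduler here are stand-ins (plain torch.optim objects), not the exact objects the fine-tuning script builds.

```python
# Sketch only: illustrates optimizer.step() before scheduler.step() with
# gradient accumulation. Model, optimizer, and scheduler are placeholders.
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR

model = torch.nn.Linear(10, 1)
optimizer = AdamW(model.parameters(), lr=3e-5)
scheduler = LambdaLR(optimizer, lr_lambda=lambda s: max(0.0, 1.0 - s / 1000))
gradient_accumulation_steps = 2

for step, batch in enumerate(torch.randn(8, 4, 10)):
    # Scale the loss so accumulated gradients match a full-batch update.
    loss = model(batch).pow(2).mean() / gradient_accumulation_steps
    loss.backward()
    if (step + 1) % gradient_accumulation_steps == 0:
        optimizer.step()    # apply the accumulated gradients first
        scheduler.step()    # then advance the learning rate schedule
        optimizer.zero_grad()
```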