From 87b9ec3843f7f9a81253075f92c9e6537ecefe1c Mon Sep 17 00:00:00 2001 From: Mathieu Prouveur Date: Mon, 29 Apr 2019 12:58:29 +0200 Subject: [PATCH] Fix tr_loss rescaling factor using global_step --- examples/run_classifier.py | 6 +++--- examples/run_swag.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/run_classifier.py b/examples/run_classifier.py index e14788cacb0..f678525b155 100644 --- a/examples/run_classifier.py +++ b/examples/run_classifier.py @@ -845,7 +845,7 @@ def main(): else: loss.backward() - tr_loss += loss.item() * args.gradient_accumulation_steps + tr_loss += loss.item() nb_tr_examples += input_ids.size(0) nb_tr_steps += 1 if (step + 1) % args.gradient_accumulation_steps == 0: @@ -936,7 +936,7 @@ def main(): elif output_mode == "regression": preds = np.squeeze(preds) result = compute_metrics(task_name, preds, all_label_ids.numpy()) - loss = tr_loss/nb_tr_steps if args.do_train else None + loss = tr_loss/global_step if args.do_train else None result['eval_loss'] = eval_loss result['global_step'] = global_step @@ -1004,7 +1004,7 @@ def main(): preds = preds[0] preds = np.argmax(preds, axis=1) result = compute_metrics(task_name, preds, all_label_ids.numpy()) - loss = tr_loss/nb_tr_steps if args.do_train else None + loss = tr_loss/global_step if args.do_train else None result['eval_loss'] = eval_loss result['global_step'] = global_step diff --git a/examples/run_swag.py b/examples/run_swag.py index 5a65d7a7487..4fb32549cbd 100644 --- a/examples/run_swag.py +++ b/examples/run_swag.py @@ -452,7 +452,7 @@ def main(): loss = loss * args.loss_scale if args.gradient_accumulation_steps > 1: loss = loss / args.gradient_accumulation_steps - tr_loss += loss.item() * args.gradient_accumulation_steps + tr_loss += loss.item() nb_tr_examples += input_ids.size(0) nb_tr_steps += 1 @@ -537,7 +537,7 @@ def main(): result = {'eval_loss': eval_loss, 'eval_accuracy': eval_accuracy, 'global_step': global_step, - 'loss': tr_loss/nb_tr_steps} + 'loss': tr_loss/global_step} output_eval_file = os.path.join(args.output_dir, "eval_results.txt") with open(output_eval_file, "w") as writer: