From 47226e4eb6beee3b8c77e4063e57ab8df879fd2a Mon Sep 17 00:00:00 2001 From: Stuart Mesham Date: Mon, 7 Sep 2020 21:45:11 +0200 Subject: [PATCH] fixed trainer tr_loss memory leak --- src/transformers/trainer.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index cc4b2ee5b48..898a124541d 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -683,7 +683,7 @@ class Trainer: self.global_step = 0 logger.info(" Starting fine-tuning.") - tr_loss = torch.tensor(0.0).to(self.args.device) + tr_loss_scalar = 0.0 logging_loss_scalar = 0.0 model.zero_grad() disable_tqdm = self.args.disable_tqdm or not self.is_local_process_zero() @@ -713,7 +713,7 @@ class Trainer: epoch_pbar.update(1) continue - tr_loss += self.training_step(model, inputs) + tr_loss_scalar += self.training_step(model, inputs) if (step + 1) % self.args.gradient_accumulation_steps == 0 or ( # last step in epoch but step is always smaller than gradient_accumulation_steps @@ -745,7 +745,6 @@ class Trainer: self.global_step == 1 and self.args.logging_first_step ): logs: Dict[str, float] = {} - tr_loss_scalar = tr_loss.item() logs["loss"] = (tr_loss_scalar - logging_loss_scalar) / self.args.logging_steps # backward compatibility for pytorch schedulers logs["learning_rate"] = ( @@ -819,7 +818,7 @@ class Trainer: delattr(self, "_past") logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n") - return TrainOutput(self.global_step, tr_loss.item() / self.global_step) + return TrainOutput(self.global_step, tr_loss_scalar / self.global_step) def hyperparameter_search( self, @@ -1024,7 +1023,7 @@ class Trainer: else: loss.backward() - return loss + return loss.item() def is_local_master(self) -> bool: """