fixed trainer tr_loss memory leak

Stuart Mesham 2020-09-07 21:45:11 +02:00
parent 90ec78b514
commit 47226e4eb6

@@ -683,7 +683,7 @@ class Trainer:
                 self.global_step = 0
                 logger.info(" Starting fine-tuning.")

-        tr_loss = torch.tensor(0.0).to(self.args.device)
+        tr_loss_scalar = 0.0
         logging_loss_scalar = 0.0
         model.zero_grad()
         disable_tqdm = self.args.disable_tqdm or not self.is_local_process_zero()
@@ -713,7 +713,7 @@ class Trainer:
                     epoch_pbar.update(1)
                     continue

-                tr_loss += self.training_step(model, inputs)
+                tr_loss_scalar += self.training_step(model, inputs)

                 if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                     # last step in epoch but step is always smaller than gradient_accumulation_steps
@@ -745,7 +745,6 @@ class Trainer:
                         self.global_step == 1 and self.args.logging_first_step
                     ):
                         logs: Dict[str, float] = {}
-                        tr_loss_scalar = tr_loss.item()
                         logs["loss"] = (tr_loss_scalar - logging_loss_scalar) / self.args.logging_steps
                         # backward compatibility for pytorch schedulers
                         logs["learning_rate"] = (
@@ -819,7 +818,7 @@ class Trainer:
             delattr(self, "_past")

         logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
-        return TrainOutput(self.global_step, tr_loss.item() / self.global_step)
+        return TrainOutput(self.global_step, tr_loss_scalar / self.global_step)

     def hyperparameter_search(
         self,
@@ -1024,7 +1023,7 @@ class Trainer:
         else:
             loss.backward()

-        return loss
+        return loss.item()

     def is_local_master(self) -> bool:
         """