Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-03 03:31:05 +06:00)

commit 47226e4eb6 (parent 90ec78b514)

fixed trainer tr_loss memory leak
@@ -683,7 +683,7 @@ class Trainer:
         self.global_step = 0
         logger.info(" Starting fine-tuning.")
 
-        tr_loss = torch.tensor(0.0).to(self.args.device)
+        tr_loss_scalar = 0.0
         logging_loss_scalar = 0.0
         model.zero_grad()
         disable_tqdm = self.args.disable_tqdm or not self.is_local_process_zero()
@@ -713,7 +713,7 @@ class Trainer:
                     epoch_pbar.update(1)
                     continue
 
-                tr_loss += self.training_step(model, inputs)
+                tr_loss_scalar += self.training_step(model, inputs)
 
                 if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                     # last step in epoch but step is always smaller than gradient_accumulation_steps
@@ -745,7 +745,6 @@ class Trainer:
                         self.global_step == 1 and self.args.logging_first_step
                     ):
                         logs: Dict[str, float] = {}
-                        tr_loss_scalar = tr_loss.item()
                         logs["loss"] = (tr_loss_scalar - logging_loss_scalar) / self.args.logging_steps
                         # backward compatibility for pytorch schedulers
                         logs["learning_rate"] = (
@@ -819,7 +818,7 @@ class Trainer:
             delattr(self, "_past")
 
         logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
-        return TrainOutput(self.global_step, tr_loss.item() / self.global_step)
+        return TrainOutput(self.global_step, tr_loss_scalar / self.global_step)
 
     def hyperparameter_search(
         self,
@@ -1024,7 +1023,7 @@ class Trainer:
         else:
             loss.backward()
 
-        return loss
+        return loss.item()
 
     def is_local_master(self) -> bool:
         """
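In short: the running loss changes from a tensor on self.args.device to a plain Python float, because training_step now returns loss.item() instead of the loss tensor, so no device tensor for the accumulated loss is kept alive across the epoch. Below is a minimal sketch of the pattern, assuming a toy model, optimizer, and stand-in training_step invented here for illustration (they are not the Trainer's actual internals); note that .item() on a CUDA tensor also forces a host/device synchronization each step.

import torch

def training_step(model: torch.nn.Module, inputs: torch.Tensor) -> float:
    # Stand-in for Trainer.training_step: forward, backward, then return
    # the loss as a Python float so no device tensor escapes the step.
    loss = model(inputs).pow(2).mean()  # toy loss for illustration
    loss.backward()
    return loss.item()  # detaches from the graph and frees the device copy

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

tr_loss_scalar = 0.0  # plain float, as in the new code; never a tensor
for step in range(10):
    inputs = torch.randn(8, 4)
    tr_loss_scalar += training_step(model, inputs)
    optimizer.step()
    model.zero_grad()

print(tr_loss_scalar / 10)  # mirrors the averaged loss in TrainOutput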