mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
[Trainer] Make sure shown loss in distributed training is correctly averaged over all workers (#13681)

* push
* improve tr loss gather
This commit is contained in:
parent
044eff5bf0
commit
91df45516c
@@ -1462,7 +1462,10 @@ class Trainer:
     def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval):
         if self.control.should_log:
             logs: Dict[str, float] = {}
-            tr_loss_scalar = tr_loss.item()
+
+            # all_gather + mean() to get average loss over all processes
+            tr_loss_scalar = self._nested_gather(tr_loss).mean().item()
+
             # reset tr_loss to zero
             tr_loss -= tr_loss
Loading…
Reference in New Issue
Block a user