[Trainer] Make sure shown loss in distributed training is correctly averaged over all workers (#13681)

* push

* improve tr loss gather
Patrick von Platen 2021-09-26 09:03:45 +02:00 committed by GitHub
parent 044eff5bf0
commit 91df45516c

@@ -1462,7 +1462,10 @@ class Trainer:
     def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval):
         if self.control.should_log:
             logs: Dict[str, float] = {}
-            tr_loss_scalar = tr_loss.item()
+            # all_gather + mean() to get average loss over all processes
+            tr_loss_scalar = self._nested_gather(tr_loss).mean().item()
 
             # reset tr_loss to zero
             tr_loss -= tr_loss
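
For context on the change: in distributed training each worker accumulates its own running loss tensor, so logging only the local value (the old `tr_loss.item()`) makes the reported loss depend on which rank happens to write the log. Gathering the per-rank losses and taking their mean gives every worker the same, globally averaged number. Below is a minimal standalone sketch of that all_gather + mean pattern using plain `torch.distributed` rather than the Trainer's `_nested_gather` helper; the function name `gather_and_average_loss` is hypothetical, for illustration only, and it assumes the default process group has already been initialized (e.g. via `torch.distributed.init_process_group`).

    # Minimal sketch of the all_gather + mean pattern; not the Trainer's actual helper.
    import torch
    import torch.distributed as dist

    def gather_and_average_loss(tr_loss: torch.Tensor) -> float:
        """Return the loss averaged over every worker in the process group."""
        if not (dist.is_available() and dist.is_initialized()):
            # Single-process training: the local value is already the global one.
            return tr_loss.item()
        world_size = dist.get_world_size()
        # One buffer per rank; all_gather fills each with that rank's tr_loss.
        gathered = [torch.zeros_like(tr_loss) for _ in range(world_size)]
        dist.all_gather(gathered, tr_loss)
        # Every rank now sees the same list, so the mean (and the logged value)
        # agrees on all workers.
        return torch.stack(gathered).mean().item()

The following `tr_loss -= tr_loss` line then zeroes the accumulator in place (rather than rebinding the name), so the same tensor object keeps accumulating losses for the next logging window.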