Workaround for #27758 to avoid ZeroDivisionError (#28756)

This commit is contained in:
Traun Leyden 2024-03-04 10:23:40 +01:00 committed by GitHub
parent 704b3f74f9
commit c38a12270a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2080,7 +2080,8 @@ class Trainer:
# add remaining tr_loss
self._total_loss_scalar += tr_loss.item()
train_loss = self._total_loss_scalar / self.state.global_step
effective_global_step = max(self.state.global_step, 0.001) # Avoid ZeroDivisionError
train_loss = self._total_loss_scalar / effective_global_step
metrics = speed_metrics(
"train",