Scale loss before backward (#35207)

Quentin Gallouédec 2024-12-23 16:16:38 +01:00 committed by GitHub
parent f5264a86ee
commit 3cd3cd50ac


@@ -3698,10 +3698,12 @@ class Trainer:
             with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                 scaled_loss.backward()
         else:
-            self.accelerator.backward(loss, **kwargs)
             # Finally we need to normalize the loss for reporting
             if num_items_in_batch is None:
-                return loss.detach() / self.args.gradient_accumulation_steps
+                loss /= self.args.gradient_accumulation_steps
+
+            self.accelerator.backward(loss, **kwargs)
+
             return loss.detach()
 
     def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
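
For context, the change applies the 1/gradient_accumulation_steps factor to the loss tensor itself, before backward(), so the accumulated gradients are averaged over micro-batches rather than only the logged value being rescaled. Below is a minimal, self-contained PyTorch sketch of that pattern; it is not the Trainer implementation, and the model, optimizer, and batch sizes are made-up names for illustration only.

import torch
from torch import nn

# Toy setup (hypothetical names, not taken from the Trainer code).
model = nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
gradient_accumulation_steps = 4

optimizer.zero_grad()
for micro_step in range(gradient_accumulation_steps):
    inputs = torch.randn(16, 8)
    labels = torch.randint(0, 2, (16,))
    loss = nn.functional.cross_entropy(model(inputs), labels)

    # Scale *before* backward: the 1/N factor then lives in the gradients,
    # not just in the detached value returned for logging.
    loss = loss / gradient_accumulation_steps
    loss.backward()

optimizer.step()
optimizer.zero_grad()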