Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 10:12:23 +06:00
Scale loss before backward (#35207)
parent f5264a86ee
commit 3cd3cd50ac
@@ -3698,10 +3698,12 @@ class Trainer:
             with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                 scaled_loss.backward()
         else:
-            self.accelerator.backward(loss, **kwargs)
             # Finally we need to normalize the loss for reporting
             if num_items_in_batch is None:
-                return loss.detach() / self.args.gradient_accumulation_steps
+                loss /= self.args.gradient_accumulation_steps
+
+            self.accelerator.backward(loss, **kwargs)
+
         return loss.detach()
 
     def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):