Fix condition when GA loss bug fix is not performed (#35651)

* fix condition when GA loss bug fix is not performed

* max loss diff is 2.29

* fix typo

* add an extra validation that loss should not vary too much
This commit is contained in:
kang sheng 2025-01-16 20:59:53 +08:00 committed by GitHub
parent fd4f14c968
commit 2cbcc5877d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 10 additions and 10 deletions

View File

@ -3672,10 +3672,7 @@ class Trainer:
return loss_mb.reduce_mean().detach().to(self.args.device)
with self.compute_loss_context_manager():
if self.model_accepts_loss_kwargs:
loss = self.compute_loss(model, inputs)
else:
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
del inputs
if (
@ -3709,7 +3706,7 @@ class Trainer:
scaled_loss.backward()
else:
# Finally we need to normalize the loss for reporting
if num_items_in_batch is None:
if not self.model_accepts_loss_kwargs and self.compute_loss_func is None:
loss = loss / self.args.gradient_accumulation_steps
self.accelerator.backward(loss, **kwargs)
@ -5157,10 +5154,6 @@ class Trainer:
except StopIteration:
break
# Keep default behavior the same
if not self.model_accepts_loss_kwargs:
return batch_samples, None
if len(batch_samples) > 0 and "labels" in batch_samples[0]:
# For now we don't support object detection
try:

View File

@ -855,7 +855,14 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
self.assertLess(max(diff_truth), 0.01, f"Difference {max(diff_truth)} is not within 0.01")
# max diff broken should be very off
self.assertGreater(max(diff_broken), 3, f"Difference {max(diff_broken)} is not greater than 3")
self.assertGreater(max(diff_broken), 2, f"Difference {max(diff_broken)} is not greater than 2")
loss_base = sum(base_loss_callback.losses)
loss_broken = sum(broken_loss_callback.losses)
# mean/sum loss should not vary too much.
relative_diff = abs(loss_base - loss_broken) / max(loss_base, loss_broken)
self.assertLess(relative_diff, 0.1, f"Relative difference {relative_diff} is not within 0.1")
@slow
def test_gradient_accumulation_loss_alignment_with_loss_func(self):