fix bug in distributed loss test (#38166)

* fix bug in distributed loss test and change some config to pass at both 2&8 gpus

* fix doc
This commit is contained in:
kang sheng 2025-05-17 00:21:35 +08:00 committed by GitHub
parent a4389494c7
commit ea29f61ed9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 5 additions and 4 deletions

View File

@ -3784,7 +3784,7 @@ class Trainer:
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
else:
# Finally we need to normalize the loss for reporting
# Finally we need to normalize the loss for reporting if GA loss bug is not fixed during compute loss
if not self.model_accepts_loss_kwargs and self.compute_loss_func is None:
loss = loss / self.args.gradient_accumulation_steps

View File

@ -26,7 +26,7 @@ class TestTrainerDistributedLoss(TestCasePlus):
@require_torch_multi_accelerator
def test_trainer(self):
device_count = backend_device_count(torch_device)
min_bs = 1
min_bs = 2
output_dir = self.get_auto_remove_tmp_dir()
for gpu_num, enable, bs, name in (
(1, True, min_bs * device_count, "base"),
@ -50,9 +50,10 @@ class TestTrainerDistributedLoss(TestCasePlus):
broken_diff = [abs(base_loss[i] - broken_loss[i]) for i in range(len(base_loss))]
fixed_diff = [abs(base_loss[i] - fixed_loss[i]) for i in range(len(base_loss))]
sum_base = sum(base_loss)
sum_broken = sum(broken_diff)
sum_broken = sum(broken_loss)
relative_broken = abs(sum_base - sum_broken) / max(sum_base, sum_broken)
# the gap may be smaller for other models, but it still ok.
self.assertGreater(max(broken_diff), 0.5)
self.assertLess(max(fixed_diff), 0.005)
self.assertLess(relative_broken, 0.1)
@ -63,7 +64,7 @@ def run_distributed_training(training_args):
model_name = "nickypro/tinyllama-15M"
dataset_name = "wikitext"
dataset_config = "wikitext-2-raw-v1"
dataset = datasets.load_dataset(dataset_name, dataset_config, split="train[:17]")
dataset = datasets.load_dataset(dataset_name, dataset_config, split="train[:100]")
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token