ddp fixes for training (#22874)

ddp fixes for stable lm training
Wing Lian 2023-04-21 11:42:02 -04:00 committed by GitHub
parent eddf9eeca0
commit d00997e66c

@@ -1565,12 +1565,13 @@ class Trainer:
                 kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb
             if is_torch_neuroncore_available():
                 return model
-            model = nn.parallel.DistributedDataParallel(
-                model,
-                device_ids=[self.args.local_rank] if self.args._n_gpu != 0 else None,
-                output_device=self.args.local_rank if self.args._n_gpu != 0 else None,
-                **kwargs,
-            )
+            if any(p.requires_grad for p in model.parameters()):
+                model = nn.parallel.DistributedDataParallel(
+                    model,
+                    device_ids=[self.args.local_rank] if self.args._n_gpu != 0 else None,
+                    output_device=self.args.local_rank if self.args._n_gpu != 0 else None,
+                    **kwargs,
+                )
 
         # torch.compile() needs to be called after wrapping the model with FSDP or DDP
         # to ensure that it accounts for the graph breaks required by those wrappers
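The first hunk only wraps the model in DistributedDataParallel when at least one parameter requires a gradient; a fully frozen model is now left unwrapped, since DDP refuses to wrap a module with nothing to synchronize. A minimal standalone sketch of that guard outside the Trainer, assuming a single-process "gloo" group purely for demonstration:

```python
import os

import torch.distributed as dist
from torch import nn


def main():
    # Single-process process group, only so that DDP could be constructed at all.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    model = nn.Linear(8, 8)
    for p in model.parameters():
        p.requires_grad_(False)  # fully frozen model, nothing to synchronize

    # Same guard as the patch: only wrap when something is trainable.
    # Wrapping the fully frozen model directly would raise a RuntimeError,
    # because DDP requires at least one parameter with requires_grad=True.
    if any(p.requires_grad for p in model.parameters()):
        model = nn.parallel.DistributedDataParallel(model)
    # else: leave the frozen model unwrapped; there are no gradients to all-reduce.

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```

Because a frozen model is left unwrapped, the second hunk can no longer assume that `model` exposes DDP's `no_sync()`.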
@@ -1920,6 +1921,7 @@ class Trainer:
                         (total_batched_samples % args.gradient_accumulation_steps != 0)
                         and args.parallel_mode == ParallelMode.DISTRIBUTED
                         and args._no_sync_in_gradient_accumulation
+                        and hasattr(model, "no_sync")
                     ):
                         # Avoid unnecessary DDP synchronization since there will be no backward pass on this example.
                         with model.no_sync():
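The added `hasattr(model, "no_sync")` check lets the gradient-accumulation path tolerate models that were left unwrapped by the change above (or wrapped by something other than DDP). A sketch of the same pattern outside the Trainer; `accumulation_step`, `model`, `batch`, and `loss_fn` are placeholder names, not Trainer internals:

```python
import contextlib


def accumulation_step(model, batch, loss_fn, sync_gradients: bool):
    """One gradient-accumulation micro-step (placeholder signature)."""
    if not sync_gradients and hasattr(model, "no_sync"):
        # DDP-wrapped model: skip the gradient all-reduce on this intermediate step.
        ctx = model.no_sync()
    else:
        # Plain nn.Module (e.g. left unwrapped by the guard above): nothing to skip.
        ctx = contextlib.nullcontext()
    with ctx:
        loss = loss_fn(model(**batch), batch["labels"])
        loss.backward()
    return loss
```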