From d00997e66c0f4f6e55a1a422d6d3ca8fb1d37aab Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Fri, 21 Apr 2023 11:42:02 -0400
Subject: [PATCH] ddp fixes for training (#22874)

ddp fixes for stable lm training
---
 src/transformers/trainer.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 7a895c0a7a3..147cdc8d9ad 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1565,12 +1565,13 @@ class Trainer:
                 kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb
             if is_torch_neuroncore_available():
                 return model
-            model = nn.parallel.DistributedDataParallel(
-                model,
-                device_ids=[self.args.local_rank] if self.args._n_gpu != 0 else None,
-                output_device=self.args.local_rank if self.args._n_gpu != 0 else None,
-                **kwargs,
-            )
+            if any(p.requires_grad for p in model.parameters()):
+                model = nn.parallel.DistributedDataParallel(
+                    model,
+                    device_ids=[self.args.local_rank] if self.args._n_gpu != 0 else None,
+                    output_device=self.args.local_rank if self.args._n_gpu != 0 else None,
+                    **kwargs,
+                )
 
         # torch.compile() needs to be called after wrapping the model with FSDP or DDP
         # to ensure that it accounts for the graph breaks required by those wrappers
@@ -1920,6 +1921,7 @@ class Trainer:
                     (total_batched_samples % args.gradient_accumulation_steps != 0)
                     and args.parallel_mode == ParallelMode.DISTRIBUTED
                     and args._no_sync_in_gradient_accumulation
+                    and hasattr(model, "no_sync")
                 ):
                     # Avoid unnecessary DDP synchronization since there will be no backward pass on this example.
                     with model.no_sync():
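
For context, a minimal sketch (not part of the patch) of the behavior these two changes guard against: DistributedDataParallel raises a RuntimeError when the wrapped module has no parameters that require gradients, as can happen when the base model is fully frozen for adapter-style fine-tuning, and a model that was never wrapped in DDP does not expose the no_sync() context manager used during gradient accumulation. The helper names below (wrap_if_trainable, no_sync_if_available, local_rank) are illustrative stand-ins, not part of the Trainer API, and a process group is assumed to be initialized before wrapping.

import contextlib

import torch.nn as nn


def wrap_if_trainable(model: nn.Module, local_rank: int) -> nn.Module:
    # Mirror the patched Trainer logic: only wrap in DDP when at least one
    # parameter is trainable; otherwise DDP refuses the module with
    # "DistributedDataParallel is not needed when a module doesn't have any
    # parameter that requires a gradient."
    # Assumes torch.distributed.init_process_group() has already been called.
    if any(p.requires_grad for p in model.parameters()):
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
        )
    return model


def no_sync_if_available(model: nn.Module):
    # Mirror the hasattr(model, "no_sync") check added to the training loop:
    # skip gradient synchronization on non-final accumulation steps only when
    # the model actually exposes DDP's no_sync(); otherwise use a no-op context.
    return model.no_sync() if hasattr(model, "no_sync") else contextlib.nullcontext()

In a training loop following this pattern, the forward/backward of non-final micro-batches would run under "with no_sync_if_available(model):" so the extra condition degrades gracefully when the model was left unwrapped.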