From b70f441b72accf3205185290efc563c0dea65bfc Mon Sep 17 00:00:00 2001
From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Date: Wed, 3 Mar 2021 12:13:29 -0500
Subject: [PATCH] Smp grad accum (#10488)

* Fix gradient accumulation for SM Model Parallelism

* Style and divide loss by grad accum steps
---
 src/transformers/sagemaker/trainer_sm.py       | 9 ++++-----
 src/transformers/sagemaker/training_args_sm.py | 4 ++++
 src/transformers/trainer.py                    | 2 +-
 src/transformers/training_args.py              | 7 +++++++
 4 files changed, 16 insertions(+), 6 deletions(-)

diff --git a/src/transformers/sagemaker/trainer_sm.py b/src/transformers/sagemaker/trainer_sm.py
index a104ee4426b..c82114acf39 100644
--- a/src/transformers/sagemaker/trainer_sm.py
+++ b/src/transformers/sagemaker/trainer_sm.py
@@ -37,9 +37,10 @@ if is_smdistributed_available():
     import smdistributed.modelparallel.torch as smp
 
     @smp.step()
-    def forward_backward(model, inputs):
+    def forward_backward(model, inputs, gradient_accumulation_steps=1):
         outputs = model(**inputs)
         loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
+        loss /= gradient_accumulation_steps
         model.backward(loss)
         return loss
 
@@ -73,8 +74,6 @@ class SageMakerTrainer(Trainer):
     def __init__(self, args=None, **kwargs):
         self.is_model_parallel_enabled = is_smdistributed_available() and args.mp_parameters != ""
         super().__init__(args=args, **kwargs)
-        if self.is_model_parallel_enabled and self.args.gradient_accumulation_steps != 1:
-            raise ValueError("Gradient accumulation is not supported when model parallel is enabled.")
 
     def is_world_process_zero(self) -> bool:
         """
@@ -108,7 +107,7 @@ class SageMakerTrainer(Trainer):
             # Wrapping the base model twice in a DistributedModel will raise an error.
             if isinstance(self.model_wrapped, smp.model.DistributedModel):
                 return self.model_wrapped
-            return smp.DistributedModel(model)
+            return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps)
         else:
             return super()._wrap_model(model)
 
@@ -121,7 +120,7 @@ class SageMakerTrainer(Trainer):
         if self.is_model_parallel_enabled:
             model.train()
             inputs = self._prepare_inputs(inputs)
-            loss_mb = forward_backward(model, inputs)
+            loss_mb = forward_backward(model, inputs, self.args.gradient_accumulation_steps)
             return loss_mb.reduce_mean().detach().to(self.args.device)
         else:
             return super().training_step(model, inputs)
diff --git a/src/transformers/sagemaker/training_args_sm.py b/src/transformers/sagemaker/training_args_sm.py
index 0aaef833caa..9b181fc4657 100644
--- a/src/transformers/sagemaker/training_args_sm.py
+++ b/src/transformers/sagemaker/training_args_sm.py
@@ -87,3 +87,7 @@ class SageMakerTrainingArguments(TrainingArguments):
     @property
     def place_model_on_device(self):
         return not (is_smdistributed_available() and self.mp_parameters != "")
+
+    @property
+    def _no_sync_in_gradient_accumulation(self):
+        return False
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 874c3ef5230..50c6a96839f 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1039,7 +1039,7 @@ class Trainer:
                 if (
                     ((step + 1) % self.args.gradient_accumulation_steps != 0)
                     and self.args.local_rank != -1
-                    and not self.args.deepspeed
+                    and self.args._no_sync_in_gradient_accumulation
                 ):
                     # Avoid unnecessary DDP synchronization since there will be no backward pass on this example.
                     with model.no_sync():
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index c683cb13a3d..7bd5964bf2e 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -737,6 +737,13 @@ class TrainingArguments:
         """
         return True
 
+    @property
+    def _no_sync_in_gradient_accumulation(self):
+        """
+        Whether or not to use no_sync for the gradients when doing gradient accumulation.
+        """
+        return not self.deepspeed
+
     def to_dict(self):
         """
         Serializes this instance while replace `Enum` by their values (for JSON serialization support).
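
For reference, the accumulation pattern this patch enables can be sketched as below. This is a minimal illustrative sketch, not code from the repository: it assumes an SMP-configured SageMaker job where `smp.init()` has already been called, and `build_model()`, `build_optimizer()`, and `dataloader` are hypothetical placeholders. The two points mirrored from the patch are dividing the loss by the number of accumulation steps inside the `smp.step`-decorated function, and passing `backward_passes_per_step` to `smp.DistributedModel` so SMP knows how many backward passes to accumulate before each optimizer step.

```python
# Hedged sketch of gradient accumulation under SageMaker Model Parallelism.
# Assumes smdistributed.modelparallel is installed and smp.init() has run;
# build_model(), build_optimizer(), and dataloader are hypothetical placeholders.
import smdistributed.modelparallel.torch as smp

accum_steps = 4  # gradient_accumulation_steps

@smp.step()
def forward_backward(model, inputs, gradient_accumulation_steps=1):
    outputs = model(**inputs)
    loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
    loss /= gradient_accumulation_steps  # scale so accumulated gradients average rather than sum
    model.backward(loss)                 # SMP uses model.backward() instead of loss.backward()
    return loss

model = smp.DistributedModel(build_model(), backward_passes_per_step=accum_steps)
optimizer = smp.DistributedOptimizer(build_optimizer(model))

for step, inputs in enumerate(dataloader):
    loss_mb = forward_backward(model, inputs, accum_steps)
    loss = loss_mb.reduce_mean()  # average the per-microbatch losses returned by smp.step
    if (step + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```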