Add basic support for FP16 in SageMaker model parallelism (#11407)

* Add FP16 support for SageMaker MP

* Add print debugs

* Squeeze

* Remove debug statements

* Add defensive check

* Typo
Author: Sylvain Gugger, 2021-04-26 08:55:14 -04:00 (committed by GitHub)
Parent: 38a716cd41
Commit: d7633a4e46
2 changed files with 22 additions and 4 deletions


@@ -412,7 +412,12 @@ class Trainer:
         if args.fp16 and not args.deepspeed:  # deepspeed manages its own fp16
             if self.fp16_backend == "amp":
                 self.use_amp = True
-                self.scaler = ShardedGradScaler() if self.sharded_ddp is not None else torch.cuda.amp.GradScaler()
+                if is_sagemaker_mp_enabled():
+                    self.scaler = smp.amp.GradScaler()
+                elif self.sharded_ddp is not None:
+                    self.scaler = ShardedGradScaler()
+                else:
+                    self.scaler = torch.cuda.amp.GradScaler()
             else:
                 if not is_apex_available():
                     raise ImportError(
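
The new branch checks SageMaker model parallelism first, before the fairscale ShardedGradScaler and plain torch.cuda.amp.GradScaler. A toy, self-contained sketch of that selection order (illustration only, not Trainer code; the helper name and its string return values are made up for the example):

# Toy illustration of the scaler selection order introduced above:
# SageMaker MP first, then fairscale sharded DDP, then stock PyTorch AMP.
def pick_scaler_backend(sagemaker_mp_enabled, sharded_ddp):
    if sagemaker_mp_enabled:
        return "smp.amp.GradScaler"
    elif sharded_ddp is not None:
        return "ShardedGradScaler"
    else:
        return "torch.cuda.amp.GradScaler"

# SageMaker MP wins even if a sharded DDP mode is also set.
assert pick_scaler_backend(True, "simple") == "smp.amp.GradScaler"
assert pick_scaler_backend(False, "simple") == "ShardedGradScaler"
assert pick_scaler_backend(False, None) == "torch.cuda.amp.GradScaler"
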
@@ -420,6 +425,13 @@ class Trainer:
                     )
                 self.use_apex = True
 
+        # FP16 + model parallelism in SageMaker: gradient clipping does not work for now so we raise a helpful error.
+        if is_sagemaker_mp_enabled() and self.use_amp and args.max_grad_norm is not None and args.max_grad_norm > 0:
+            raise ValueError(
+                "SageMaker Model Parallelism in mixed precision mode does not support gradient clipping yet. Pass "
+                "along 'max_grad_norm': 0 in your hyperparameters."
+            )
+
         # Label smoothing
         if self.args.label_smoothing_factor != 0:
             self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor)
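
On the user side, the check above is satisfied by turning gradient clipping off. A minimal sketch, assuming a script configured through TrainingArguments (the argument names are the real transformers ones; output_dir and the concrete values are illustrative):

from transformers import TrainingArguments

# max_grad_norm defaults to 1.0, which the defensive check above rejects under
# SageMaker model parallelism + FP16; 0 disables gradient clipping entirely.
training_args = TrainingArguments(
    output_dir="output",
    fp16=True,        # takes the AMP path configured above
    max_grad_norm=0,  # the "max_grad_norm": 0 the error message asks for
)

When launching through a SageMaker estimator, the same value is passed in the job's hyperparameters dictionary, which is what the error message refers to.
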
@@ -1607,7 +1619,8 @@ class Trainer:
         inputs = self._prepare_inputs(inputs)
 
         if is_sagemaker_mp_enabled():
-            loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
+            scaler = self.scaler if self.use_amp else None
+            loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps, scaler=scaler)
             return loss_mb.reduce_mean().detach().to(self.args.device)
 
         if self.use_amp:
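
For contrast, the if self.use_amp: branch that follows the SageMaker early return uses the stock torch.cuda.amp machinery. A minimal, self-contained sketch of that pattern (the toy model, optimizer and data are illustrative, not Trainer code):

import torch
from torch import nn

# Standard torch.cuda.amp recipe: autocast forward, backward on the scaled loss,
# scaler-driven optimizer step. Degrades gracefully to FP32 when no GPU is present.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(8, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))

inputs = torch.randn(4, 8, device=device)
labels = torch.randint(0, 2, (4,), device=device)

with torch.cuda.amp.autocast(enabled=(device == "cuda")):
    loss = nn.functional.cross_entropy(model(inputs), labels)

scaler.scale(loss).backward()  # backward on the (possibly) scaled loss
scaler.step(optimizer)         # unscales gradients, then optimizer.step()
scaler.update()                # adjusts the scale factor for the next step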


@@ -974,10 +974,15 @@ if is_sagemaker_mp_enabled():
     import smdistributed.modelparallel.torch as smp
 
     @smp.step()
-    def smp_forward_backward(model, inputs, gradient_accumulation_steps=1):
-        outputs = model(**inputs)
+    def smp_forward_backward(model, inputs, gradient_accumulation_steps=1, scaler=None):
+        with torch.cuda.amp.autocast(enabled=(scaler is not None)):
+            outputs = model(**inputs)
         loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
         loss /= gradient_accumulation_steps
+        if scaler is not None:
+            loss = scaler.scale(loss).squeeze()
         model.backward(loss)
         return loss
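
Putting the two files together, the FP16 call path on the Trainer side now looks roughly like this. It is a sketch assembled from the diff above rather than standalone code: model, inputs and args stand in for Trainer state, and it only runs inside a SageMaker model-parallel job.

# Sketch of the new FP16 + SageMaker MP path (illustrative only).
scaler = smp.amp.GradScaler() if args.fp16 else None

# @smp.step runs the wrapped forward/backward per microbatch; with a scaler,
# the forward executes under autocast and the loss is scaled before model.backward.
loss_mb = smp_forward_backward(model, inputs, args.gradient_accumulation_steps, scaler=scaler)

# Average the per-microbatch losses back into a single detached scalar.
loss = loss_mb.reduce_mean().detach().to(args.device)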