Mirror of https://github.com/huggingface/transformers.git
Add basic support for FP16 in SageMaker model parallelism (#11407)
* Add FP16 support for SageMaker MP
* Add print debugs
* Squeeze
* Remove debug statements
* Add defensive check
* Typo
parent 38a716cd41
commit d7633a4e46
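For context on how this change surfaces to users: fp16 is requested through TrainingArguments as usual, and under SageMaker model parallelism gradient clipping must be disabled (see the new check in the second hunk below). A minimal, hypothetical configuration sketch, not part of this commit; all values are placeholders and it assumes a GPU training environment such as a SageMaker instance:

    # Hypothetical user-side sketch, not from this commit: how FP16 would be
    # requested when training under SageMaker model parallelism.
    from transformers import TrainingArguments

    args = TrainingArguments(
        output_dir="output",              # placeholder path
        fp16=True,                        # request mixed precision (the amp backend)
        max_grad_norm=0.0,                # gradient clipping is not supported with SMP + FP16 yet
        per_device_train_batch_size=8,
    )
    # `args` would then be passed to Trainer(model=..., args=args, train_dataset=...)
    # inside the training script launched on SageMaker.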
@@ -412,7 +412,12 @@ class Trainer:
         if args.fp16 and not args.deepspeed:  # deepspeed manages its own fp16
             if self.fp16_backend == "amp":
                 self.use_amp = True
-                self.scaler = ShardedGradScaler() if self.sharded_ddp is not None else torch.cuda.amp.GradScaler()
+                if is_sagemaker_mp_enabled():
+                    self.scaler = smp.amp.GradScaler()
+                elif self.sharded_ddp is not None:
+                    self.scaler = ShardedGradScaler()
+                else:
+                    self.scaler = torch.cuda.amp.GradScaler()
             else:
                 if not is_apex_available():
                     raise ImportError(
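The hunk above selects one of three gradient scalers for the amp backend: smp.amp.GradScaler under SageMaker model parallelism, ShardedGradScaler for sharded DDP, and torch.cuda.amp.GradScaler otherwise. As a reminder of what the scaler is for, here is a minimal plain-PyTorch sketch (no SageMaker, illustrative names only) of the scale/step/update cycle all three variants implement:

    # Minimal plain-PyTorch sketch of the cycle shared by the scaler variants above.
    # Requires a CUDA device, since torch.cuda.amp is GPU-only.
    import torch

    model = torch.nn.Linear(8, 2).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    scaler = torch.cuda.amp.GradScaler()  # the "else" branch in the hunk above

    for _ in range(3):
        inputs = torch.randn(4, 8, device="cuda")
        labels = torch.randint(0, 2, (4,), device="cuda")
        with torch.cuda.amp.autocast():    # run the forward pass in FP16 where safe
            loss = torch.nn.functional.cross_entropy(model(inputs), labels)
        scaler.scale(loss).backward()      # scale the loss to avoid FP16 gradient underflow
        scaler.step(optimizer)             # unscale gradients, skip the step on overflow
        scaler.update()                    # adjust the scale factor for the next iteration
        optimizer.zero_grad()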
@@ -420,6 +425,13 @@ class Trainer:
                     )
                 self.use_apex = True
 
+        # FP16 + model parallelism in SageMaker: gradient clipping does not work for now so we raise a helpful error.
+        if is_sagemaker_mp_enabled() and self.use_amp and args.max_grad_norm is not None and args.max_grad_norm > 0:
+            raise ValueError(
+                "SageMaker Model Parallelism in mixed precision mode does not support gradient clipping yet. Pass "
+                "along 'max_grad_norm': 0 in your hyperparameters."
+            )
+
         # Label smoothing
         if self.args.label_smoothing_factor != 0:
             self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor)
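The new error tells users to pass 'max_grad_norm': 0 in their hyperparameters. A hedged sketch of where that lands when launching through the SageMaker Python SDK's HuggingFace estimator; every script name, role, instance type, version string, and distribution value below is a placeholder, not something prescribed by this commit:

    # Hedged sketch, not from this commit: hyperparameters passed to a SageMaker
    # HuggingFace estimator become command-line flags for the training script.
    from sagemaker.huggingface import HuggingFace

    hyperparameters = {
        "model_name_or_path": "bert-base-uncased",
        "output_dir": "/opt/ml/model",
        "fp16": True,
        "max_grad_norm": 0,  # disable gradient clipping, as required by the check above
    }

    estimator = HuggingFace(
        entry_point="run_glue.py",          # placeholder training script
        role="my-sagemaker-role",           # placeholder IAM role
        instance_type="ml.p3.16xlarge",     # placeholder instance type
        instance_count=1,
        transformers_version="4.6",         # illustrative framework versions
        pytorch_version="1.7",
        py_version="py36",
        hyperparameters=hyperparameters,
        distribution={
            "smdistributed": {"modelparallel": {"enabled": True, "parameters": {"partitions": 2, "microbatches": 4}}},
            "mpi": {"enabled": True, "processes_per_host": 8},
        },
    )
    # estimator.fit(...) would start the training job; the hyperparameters are
    # parsed into TrainingArguments by the entry-point script.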
@@ -1607,7 +1619,8 @@ class Trainer:
         inputs = self._prepare_inputs(inputs)
 
         if is_sagemaker_mp_enabled():
-            loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
+            scaler = self.scaler if self.use_amp else None
+            loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps, scaler=scaler)
             return loss_mb.reduce_mean().detach().to(self.args.device)
 
         if self.use_amp:
@@ -974,10 +974,15 @@ if is_sagemaker_mp_enabled():
     import smdistributed.modelparallel.torch as smp
 
     @smp.step()
-    def smp_forward_backward(model, inputs, gradient_accumulation_steps=1):
-        outputs = model(**inputs)
+    def smp_forward_backward(model, inputs, gradient_accumulation_steps=1, scaler=None):
+        with torch.cuda.amp.autocast(enabled=(scaler is not None)):
+            outputs = model(**inputs)
+
         loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
         loss /= gradient_accumulation_steps
+        if scaler is not None:
+            loss = scaler.scale(loss).squeeze()
+
         model.backward(loss)
         return loss
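For readers without access to smdistributed, the pattern of the new smp_forward_backward can be mirrored in plain PyTorch: run the forward pass under autocast only when a scaler is supplied, divide by the gradient accumulation steps, and scale the loss before backpropagating. A self-contained sketch under those assumptions (ToyModel and forward_backward are illustrative names, not part of the library; requires a CUDA device):

    # Plain-PyTorch sketch of the pattern in the hunk above; the optimizer step
    # and scaler.update() would happen elsewhere, as they do in the Trainer.
    import torch

    def forward_backward(model, inputs, gradient_accumulation_steps=1, scaler=None):
        with torch.cuda.amp.autocast(enabled=(scaler is not None)):
            outputs = model(**inputs)
        loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
        loss = loss / gradient_accumulation_steps
        if scaler is not None:
            loss = scaler.scale(loss)   # pre-scale so backward produces scaled gradients
        loss.backward()                 # SMP calls model.backward(loss) here instead
        return loss.detach()

    class ToyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 2)

        def forward(self, input_ids, labels):
            logits = self.linear(input_ids)
            loss = torch.nn.functional.cross_entropy(logits, labels)
            return {"loss": loss, "logits": logits}

    model = ToyModel().cuda()
    scaler = torch.cuda.amp.GradScaler()
    batch = {"input_ids": torch.randn(3, 4, device="cuda"), "labels": torch.randint(0, 2, (3,), device="cuda")}
    loss = forward_backward(model, batch, gradient_accumulation_steps=2, scaler=scaler)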