Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-03 04:40:06 +06:00)
fix: support grad clipping for TP through replicating non-sharded modules (#36132)
* feat: fix tp grad norm

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

* feat: use implicit replication

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

---------

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
parent fca6748246
commit 3d15606e64
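The change described above wraps the gradient-norm computation in implicit replication so that plain (non-sharded) parameters can be combined with DTensor-sharded ones in a single norm. Below is a minimal standalone sketch of that pattern, assuming an already-initialized tensor-parallel setup; it uses torch.nn.utils.clip_grad_norm_ in place of Accelerate's clip_grad_norm_ wrapper, and the function name and tp_enabled flag are illustrative, not part of the Trainer API.

import contextlib

import torch
from torch.distributed._tensor.experimental import implicit_replication


def clip_grad_norm_tp(model: torch.nn.Module, max_grad_norm: float, tp_enabled: bool) -> torch.Tensor:
    # Under tensor parallelism some parameters are DTensors (sharded) while
    # non-parallelized modules keep plain tensors; computing one norm over the
    # mix only works if the plain tensors are implicitly treated as replicated.
    grad_norm_context = contextlib.nullcontext
    if tp_enabled:
        grad_norm_context = implicit_replication
    with grad_norm_context():
        # Hypothetical stand-in for self.accelerator.clip_grad_norm_(...)
        return torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)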
@@ -232,6 +232,7 @@ if is_accelerate_available():
         AutocastKwargs,
         DistributedDataParallelKwargs,
         DistributedType,
+        TorchTensorParallelPlugin,
         load_fsdp_model,
         load_fsdp_optimizer,
         save_fsdp_model,
@@ -2299,7 +2300,9 @@ class Trainer:
         else:
             debug_overflow = DebugUnderflowOverflow(self.model)  # noqa

-        delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled
+        delay_optimizer_creation = (
+            is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled or self.is_tp_enabled
+        )

         # Can't delay optimizer creation when using FSDP2: https://github.com/huggingface/accelerate/blob/3f636d626063ffcf9a337c7d3624d61b7d187d59/src/accelerate/accelerator.py#L1404
         is_fsdp2 = self.is_fsdp_enabled and (getattr(self.accelerator.state.fsdp_plugin, "fsdp_version", 1) == 2)
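The hunk above also adds self.is_tp_enabled to delay_optimizer_creation. A hedged illustration of the usual rationale (the names below are placeholders, not Trainer code): the optimizer should be built only after the parallelism plugin has converted the model's parameters, so its state references the parameter objects that will actually carry gradients.

import torch
from torch import nn


def build_optimizer_after_parallelize(model: nn.Module, parallelize_fn, lr: float = 1e-3):
    # parallelize_fn stands in for whatever shards/replicates the parameters
    # (e.g. the TP plugin applied inside accelerator.prepare); it is a placeholder.
    model = parallelize_fn(model)
    # Creating the optimizer after parallelization keeps its state tied to the
    # parameters the prepared model actually exposes.
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    return model, optimizer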
@@ -2359,7 +2362,10 @@ class Trainer:
                 if self.use_apex:
                     model = self.accelerator.prepare(self.model)
                 else:
-                    model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
+                    if delay_optimizer_creation:
+                        self.optimizer = self.accelerator.prepare(self.optimizer)
+                    else:
+                        model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
             else:
                 # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config.
                 model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
@@ -2580,10 +2586,16 @@ class Trainer:
                             args.max_grad_norm,
                         )
                     else:
-                        _grad_norm = self.accelerator.clip_grad_norm_(
-                            model.parameters(),
-                            args.max_grad_norm,
-                        )
+                        grad_norm_context = contextlib.nullcontext
+                        if self.is_tp_enabled:
+                            from torch.distributed._tensor.experimental import implicit_replication
+
+                            grad_norm_context = implicit_replication
+                        with grad_norm_context():
+                            _grad_norm = self.accelerator.clip_grad_norm_(
+                                model.parameters(),
+                                args.max_grad_norm,
+                            )

                     if (
                         is_accelerate_available()