fix: support grad clipping for TP through replicating non-sharded modules (#36132)

* feat: fix tp grad norm:

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

* feat: use implicit replication

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

---------

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
Mehant Kammakomati 2025-06-06 14:37:22 +05:30 committed by GitHub
parent fca6748246
commit 3d15606e64
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -232,6 +232,7 @@ if is_accelerate_available():
AutocastKwargs,
DistributedDataParallelKwargs,
DistributedType,
TorchTensorParallelPlugin,
load_fsdp_model,
load_fsdp_optimizer,
save_fsdp_model,
@ -2299,7 +2300,9 @@ class Trainer:
else:
debug_overflow = DebugUnderflowOverflow(self.model) # noqa
delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled
delay_optimizer_creation = (
is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled or self.is_tp_enabled
)
# Can't delay optimizer creation when using FSDP2: https://github.com/huggingface/accelerate/blob/3f636d626063ffcf9a337c7d3624d61b7d187d59/src/accelerate/accelerator.py#L1404
is_fsdp2 = self.is_fsdp_enabled and (getattr(self.accelerator.state.fsdp_plugin, "fsdp_version", 1) == 2)
@ -2359,7 +2362,10 @@ class Trainer:
if self.use_apex:
model = self.accelerator.prepare(self.model)
else:
model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
if delay_optimizer_creation:
self.optimizer = self.accelerator.prepare(self.optimizer)
else:
model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
else:
# to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config.
model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
@ -2580,10 +2586,16 @@ class Trainer:
args.max_grad_norm,
)
else:
_grad_norm = self.accelerator.clip_grad_norm_(
model.parameters(),
args.max_grad_norm,
)
grad_norm_context = contextlib.nullcontext
if self.is_tp_enabled:
from torch.distributed._tensor.experimental import implicit_replication
grad_norm_context = implicit_replication
with grad_norm_context():
_grad_norm = self.accelerator.clip_grad_norm_(
model.parameters(),
args.max_grad_norm,
)
if (
is_accelerate_available()