diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index ca03db99471..d7c567f5391 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -232,6 +232,7 @@ if is_accelerate_available():
         AutocastKwargs,
         DistributedDataParallelKwargs,
         DistributedType,
+        TorchTensorParallelPlugin,
         load_fsdp_model,
         load_fsdp_optimizer,
         save_fsdp_model,
@@ -2299,7 +2300,9 @@ class Trainer:
             else:
                 debug_overflow = DebugUnderflowOverflow(self.model)  # noqa
 
-        delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled
+        delay_optimizer_creation = (
+            is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled or self.is_tp_enabled
+        )
 
         # Can't delay optimizer creation when using FSDP2: https://github.com/huggingface/accelerate/blob/3f636d626063ffcf9a337c7d3624d61b7d187d59/src/accelerate/accelerator.py#L1404
         is_fsdp2 = self.is_fsdp_enabled and (getattr(self.accelerator.state.fsdp_plugin, "fsdp_version", 1) == 2)
@@ -2359,7 +2362,10 @@ class Trainer:
                 if self.use_apex:
                     model = self.accelerator.prepare(self.model)
                 else:
-                    model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
+                    if delay_optimizer_creation:
+                        self.optimizer = self.accelerator.prepare(self.optimizer)
+                    else:
+                        model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
             else:
                 # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config.
                 model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
@@ -2580,10 +2586,16 @@ class Trainer:
                             args.max_grad_norm,
                         )
                     else:
-                        _grad_norm = self.accelerator.clip_grad_norm_(
-                            model.parameters(),
-                            args.max_grad_norm,
-                        )
+                        grad_norm_context = contextlib.nullcontext
+                        if self.is_tp_enabled:
+                            from torch.distributed._tensor.experimental import implicit_replication
+
+                            grad_norm_context = implicit_replication
+                        with grad_norm_context():
+                            _grad_norm = self.accelerator.clip_grad_norm_(
+                                model.parameters(),
+                                args.max_grad_norm,
+                            )
 
                     if (
                         is_accelerate_available()
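
The last hunk wraps gradient clipping in `implicit_replication()` when tensor parallelism is enabled. The sketch below is not taken from the PR; it is a hypothetical standalone script (module names, sizes, and the `torchrun --nproc_per_node=2` launch are illustrative assumptions) showing why the context manager is needed when `model.parameters()` mixes DTensor and plain tensor gradients.

    # Hypothetical sketch, not Trainer code. Assumes 2 GPUs and a launch via
    # `torchrun --nproc_per_node=2 sketch.py`; names and sizes are illustrative.
    import os

    import torch
    from torch.distributed._tensor import Replicate
    from torch.distributed._tensor.experimental import implicit_replication
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    mesh = init_device_mesh("cuda", (2,))  # 1-D mesh used as the tensor-parallel group

    model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.LayerNorm(16)).cuda()
    # Shard only the Linear layer; the LayerNorm keeps plain (non-DTensor) parameters.
    parallelize_module(model, mesh, {"0": ColwiseParallel(output_layouts=Replicate())})

    model(torch.randn(4, 16, device="cuda")).sum().backward()

    # model.parameters() now mixes DTensor grads (Linear) with plain tensor grads
    # (LayerNorm). Computing one total norm across both can fail unless the plain
    # tensors are implicitly treated as replicated DTensors, which is what the
    # context manager provides.
    with implicit_replication():
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

The Trainer change follows the same pattern, except that it falls back to `contextlib.nullcontext` when tensor parallelism is not enabled, so the non-TP code path is unchanged.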