diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 74061b1d22f..379881f6893 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -148,6 +148,7 @@ from .utils import (
     is_sagemaker_dp_enabled,
     is_sagemaker_mp_enabled,
     is_torch_compile_available,
+    is_torch_neuroncore_available,
     is_torch_tpu_available,
     logging,
 )
@@ -1537,6 +1538,9 @@ class Trainer:
         if self.args.ddp_bucket_cap_mb is not None:
             kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb
 
+        # Skip DDP wrapping on NeuronCore; return the model unwrapped.
+        if is_torch_neuroncore_available():
+            return model
         model = nn.parallel.DistributedDataParallel(
             model,
             device_ids=[self.args.local_rank] if self.args._n_gpu != 0 else None,
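
Note that `is_torch_neuroncore_available` must be called with parentheses in the hunk above: the bare function object is always truthy, which would make every distributed run skip DDP wrapping and return the unwrapped model. For context, here is a minimal sketch of what such an availability check might look like, assuming NeuronCore support is signalled by the `torch_neuronx` package being importable; the actual helper lives in `transformers`' import utilities and may differ:

```python
import importlib.util
from functools import lru_cache


@lru_cache()
def is_torch_neuroncore_available() -> bool:
    # Hypothetical sketch: treat an importable torch_neuronx package
    # (AWS Trainium's PyTorch integration) as the availability signal.
    return importlib.util.find_spec("torch_neuronx") is not None
```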