diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index ad0df7f99d1..dc692f5aa74 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -87,9 +87,9 @@ if is_torch_neuroncore_available(check_device=False):
             )
             import torch_xla.distributed.xla_backend as xbn

-            if not isinstance(torch.distributed.group.WORLD, xbn.ProcessGroupXla):
-                torch.distributed.init_process_group(backend="xla")
-                if not isinstance(torch.distributed.group.WORLD, xbn.ProcessGroupXla):
+            if not isinstance(dist.group.WORLD, xbn.ProcessGroupXla):
+                dist.init_process_group(backend="xla")
+                if not isinstance(dist.group.WORLD, xbn.ProcessGroupXla):
                     raise AssertionError("Failed to initialize torch.distributed process group using XLA backend.")


@@ -1716,11 +1716,7 @@ class TrainingArguments:
         if not is_sagemaker_mp_enabled():
             device = self.distributed_state.device
             self.local_rank = self.distributed_state.local_process_index
-        if (
-            torch.distributed.is_available()
-            and torch.distributed.is_initialized()
-            and self.parallel_mode != ParallelMode.DISTRIBUTED
-        ):
+        if dist.is_available() and dist.is_initialized() and self.parallel_mode != ParallelMode.DISTRIBUTED:
             logger.warning(
                 "torch.distributed process group is initialized, but parallel_mode != ParallelMode.DISTRIBUTED. "
                 "In order to use Torch DDP, launch your script with `python -m torch.distributed.launch"
@@ -1963,10 +1959,8 @@ class TrainingArguments:
                     logger.debug(f"{self.process_index}: waiting for the {main_process_desc} to perform {desc}")
                     if is_torch_tpu_available():
                         xm.rendezvous(desc)
-                    elif is_sagemaker_dp_enabled():
-                        dist.barrier()
                     else:
-                        torch.distributed.barrier()
+                        dist.barrier()
                 yield
             finally:
                 if is_main_process:
@@ -1974,10 +1968,8 @@ class TrainingArguments:
                     logger.debug(f"{self.process_index}: {main_process_desc} completed {desc}, releasing all replicas")
                     if is_torch_tpu_available():
                         xm.rendezvous(desc)
-                    elif is_sagemaker_dp_enabled():
-                        dist.barrier()
                     else:
-                        torch.distributed.barrier()
+                        dist.barrier()
         else:
             yield
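
The net effect of these hunks is that every collective call in `training_args.py` goes through a single `dist` alias instead of mixing `torch.distributed.*` with a SageMaker-specific `elif is_sagemaker_dp_enabled():` branch, presumably because recent SageMaker data parallel registers itself as a regular `torch.distributed` backend, so the same `dist.barrier()` call covers both cases. Below is a minimal standalone sketch of the resulting barrier pattern, not the transformers implementation itself: it assumes the module-level `import torch.distributed as dist` alias (the import sits outside these hunks) and uses a hypothetical helper name `wait_for_everyone`.

```python
# Minimal sketch of the consolidated barrier pattern; `wait_for_everyone`
# is an illustrative name, not a transformers API.
import torch.distributed as dist


def wait_for_everyone() -> None:
    # Only synchronize when a process group actually exists; calling
    # dist.barrier() in a single-process run would raise.
    if dist.is_available() and dist.is_initialized():
        dist.barrier()


if __name__ == "__main__":
    # No-op when run as a single process; under torchrun each rank
    # blocks here until all ranks have reached the barrier.
    wait_for_everyone()
```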