Clean up dist import (#24402)

commit 1a6fb930fb
parent 285a48011d
Author: Zach Mueller (committed by GitHub)
Date:   2023-06-21 11:19:42 -04:00


@@ -87,9 +87,9 @@ if is_torch_neuroncore_available(check_device=False):
             )
             import torch_xla.distributed.xla_backend as xbn
 
-            if not isinstance(torch.distributed.group.WORLD, xbn.ProcessGroupXla):
-                torch.distributed.init_process_group(backend="xla")
-                if not isinstance(torch.distributed.group.WORLD, xbn.ProcessGroupXla):
+            if not isinstance(dist.group.WORLD, xbn.ProcessGroupXla):
+                dist.init_process_group(backend="xla")
+                if not isinstance(dist.group.WORLD, xbn.ProcessGroupXla):
                     raise AssertionError("Failed to initialize torch.distributed process group using XLA backend.")
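
The rename above relies on `dist` being an alias for `torch.distributed`, imported once at the top of the file. A minimal sketch of that equivalence, independent of the XLA-specific code and assuming only a standard PyTorch install:

import torch
import torch.distributed as dist

# `dist` is just another name for the torch.distributed module, so
# dist.group.WORLD, dist.init_process_group and dist.barrier are the very
# same objects the removed torch.distributed.* lines referred to.
assert dist is torch.distributed

# True only when this PyTorch build ships with distributed support.
print(dist.is_available())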
@@ -1716,11 +1716,7 @@ class TrainingArguments:
         if not is_sagemaker_mp_enabled():
             device = self.distributed_state.device
             self.local_rank = self.distributed_state.local_process_index
-        if (
-            torch.distributed.is_available()
-            and torch.distributed.is_initialized()
-            and self.parallel_mode != ParallelMode.DISTRIBUTED
-        ):
+        if dist.is_available() and dist.is_initialized() and self.parallel_mode != ParallelMode.DISTRIBUTED:
             logger.warning(
                 "torch.distributed process group is initialized, but parallel_mode != ParallelMode.DISTRIBUTED. "
                 "In order to use Torch DDP, launch your script with `python -m torch.distributed.launch"
@@ -1963,10 +1959,8 @@ class TrainingArguments:
                     logger.debug(f"{self.process_index}: waiting for the {main_process_desc} to perform {desc}")
                     if is_torch_tpu_available():
                         xm.rendezvous(desc)
-                    elif is_sagemaker_dp_enabled():
-                        dist.barrier()
                     else:
-                        torch.distributed.barrier()
+                        dist.barrier()
                 yield
             finally:
                 if is_main_process:
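
Both barrier hunks touch the same "main process first" pattern: non-main ranks block on a barrier until the main process finishes some work, then the main process releases them from the finally block. A simplified sketch of that pattern using only `torch.distributed` (no TPU or SageMaker branches; `main_process_first_sketch` is a hypothetical name, not the Trainer API):

import contextlib

import torch.distributed as dist


@contextlib.contextmanager
def main_process_first_sketch(is_main_process: bool, desc: str = "work"):
    # Simplified sketch of the context manager changed above: non-main ranks
    # wait at a barrier until the main process has performed `desc`, then the
    # main process releases them with a second barrier in the finally block.
    if dist.is_available() and dist.is_initialized():
        try:
            if not is_main_process:
                dist.barrier()  # wait for the main process to perform `desc`
            yield
        finally:
            if is_main_process:
                dist.barrier()  # the wait is over, release all replicas
    else:
        # Single-process runs (or builds without distributed support) need no
        # synchronization at all.
        yield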
@@ -1974,10 +1968,8 @@ class TrainingArguments:
                     logger.debug(f"{self.process_index}: {main_process_desc} completed {desc}, releasing all replicas")
                     if is_torch_tpu_available():
                         xm.rendezvous(desc)
-                    elif is_sagemaker_dp_enabled():
-                        dist.barrier()
                     else:
-                        torch.distributed.barrier()
+                        dist.barrier()
         else:
             yield
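
As a usage note, this pattern usually wraps work that only rank 0 should do once, such as writing a preprocessing cache. A hypothetical call site, reusing the `main_process_first_sketch` helper from the sketch above (RANK is the environment variable set by launchers such as torchrun):

import os

# Requires the main_process_first_sketch helper defined in the sketch above.
rank = int(os.environ.get("RANK", "0"))

with main_process_first_sketch(is_main_process=(rank == 0), desc="building cache"):
    # Rank 0 writes the cache first; the other ranks wait at the barrier and
    # only reach this point once the file exists.
    if rank == 0:
        with open("cache.txt", "w") as f:
            f.write("preprocessed data")

with open("cache.txt") as f:
    data = f.read()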