Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-25 23:38:59 +06:00
Clean up dist import (#24402)
commit 1a6fb930fb (parent 285a48011d)
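The change is mechanical: the fully qualified torch.distributed.* calls in the hunks below are switched to the module-level dist alias, and a SageMaker-DP-specific branch that ends up identical to the default branch is dropped. A minimal sketch of what the alias means, assuming only that torch is installed (the import hunk itself is not part of this excerpt):

# "dist" is just an alias for the torch.distributed module, so dist.barrier,
# dist.init_process_group, etc. are the very same objects as their fully
# qualified counterparts.
import torch
import torch.distributed as dist

assert dist is torch.distributed
assert dist.barrier is torch.distributed.barrier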
@@ -87,9 +87,9 @@ if is_torch_neuroncore_available(check_device=False):
             )
             import torch_xla.distributed.xla_backend as xbn
 
-            if not isinstance(torch.distributed.group.WORLD, xbn.ProcessGroupXla):
-                torch.distributed.init_process_group(backend="xla")
-                if not isinstance(torch.distributed.group.WORLD, xbn.ProcessGroupXla):
+            if not isinstance(dist.group.WORLD, xbn.ProcessGroupXla):
+                dist.init_process_group(backend="xla")
+                if not isinstance(dist.group.WORLD, xbn.ProcessGroupXla):
                     raise AssertionError("Failed to initialize torch.distributed process group using XLA backend.")
 
 
@@ -1716,11 +1716,7 @@ class TrainingArguments:
         if not is_sagemaker_mp_enabled():
             device = self.distributed_state.device
             self.local_rank = self.distributed_state.local_process_index
-        if (
-            torch.distributed.is_available()
-            and torch.distributed.is_initialized()
-            and self.parallel_mode != ParallelMode.DISTRIBUTED
-        ):
+        if dist.is_available() and dist.is_initialized() and self.parallel_mode != ParallelMode.DISTRIBUTED:
             logger.warning(
                 "torch.distributed process group is initialized, but parallel_mode != ParallelMode.DISTRIBUTED. "
                 "In order to use Torch DDP, launch your script with `python -m torch.distributed.launch"
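A side note on the collapsed condition: the availability check has to stay in front, because when PyTorch is built without distributed support torch.distributed exposes essentially nothing beyond is_available(). A standalone sketch of the same guard, with a hypothetical helper name chosen here for illustration:

import torch.distributed as dist


def torch_distributed_is_active() -> bool:
    # Short-circuit on is_available() first; is_initialized() is only
    # meaningful (and only reliably present) when the distributed package
    # is compiled into the build.
    return dist.is_available() and dist.is_initialized()


if torch_distributed_is_active():
    print("an initialized torch.distributed process group is present")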
@@ -1963,10 +1959,8 @@ class TrainingArguments:
                     logger.debug(f"{self.process_index}: waiting for the {main_process_desc} to perform {desc}")
                     if is_torch_tpu_available():
                         xm.rendezvous(desc)
-                    elif is_sagemaker_dp_enabled():
-                        dist.barrier()
                     else:
-                        torch.distributed.barrier()
+                        dist.barrier()
                 yield
             finally:
                 if is_main_process:
@@ -1974,10 +1968,8 @@ class TrainingArguments:
                     logger.debug(f"{self.process_index}: {main_process_desc} completed {desc}, releasing all replicas")
                     if is_torch_tpu_available():
                         xm.rendezvous(desc)
-                    elif is_sagemaker_dp_enabled():
-                        dist.barrier()
                     else:
-                        torch.distributed.barrier()
+                        dist.barrier()
         else:
             yield
 
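The two hunks above are the two halves of TrainingArguments.main_process_first: non-main ranks block on a barrier before the body runs, and the main rank releases them with the matching barrier afterwards. Below is a minimal, standalone sketch of that barrier pattern, assuming a process group is already initialized; the names are illustrative, not the library API:

import contextlib

import torch.distributed as dist


@contextlib.contextmanager
def main_process_first(rank: int):
    # Assumes dist.init_process_group(...) has already been called.
    is_main = rank == 0
    try:
        if not is_main:
            dist.barrier()  # non-main ranks wait here until the main rank is done
        yield
    finally:
        if is_main:
            dist.barrier()  # the main rank releases the waiting ranks

Each rank participates in exactly one barrier per pass through the context manager, so main and non-main ranks rendezvous exactly once.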