Adds timeout argument to training_args to avoid socket timeouts in DDP (#18562)

* chore(training_args): Adds support for timeout argument.

* fix(training_args): Passes make style through changes.

* fix(training_args): Removes wrong docstring sentence.

* fix(training_args): Fixes timeout not being JSON serializable.

* fix(training_args_sm): Also updates timeout to timeout_delta.

* fix(training_args): Fixes PR according to suggestions.
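
For context, a minimal usage sketch of the new argument (not part of this commit's diff). It assumes torch and a transformers release that includes this change are installed; "./outputs" and the 7200-second value are hypothetical.

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="./outputs",
    ddp_timeout=7200,  # seconds; the default remains 1800
)

# The integer is exposed to torch.distributed as a datetime.timedelta:
print(args.ddp_timeout_delta)  # 2:00:00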
Gustavo de Rosa, 2022-09-01 11:33:53 -03:00, committed by GitHub
parent ab663b2274
commit fe58929ad6
2 changed files with 26 additions and 5 deletions

src/transformers/sagemaker/training_args_sm.py

@@ -92,7 +92,7 @@ class SageMakerTrainingArguments(TrainingArguments):
         elif is_sagemaker_dp_enabled():
             import smdistributed.dataparallel.torch.torch_smddp  # noqa: F401

-            torch.distributed.init_process_group(backend="smddp")
+            torch.distributed.init_process_group(backend="smddp", timeout=self.ddp_timeout_delta)
             self.local_rank = int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))
             device = torch.device("cuda", self.local_rank)
             self._n_gpu = 1
@@ -111,7 +111,7 @@ class SageMakerTrainingArguments(TrainingArguments):
             # Here, we'll use torch.distributed.
             # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
             if not torch.distributed.is_initialized():
-                torch.distributed.init_process_group(backend="nccl")
+                torch.distributed.init_process_group(backend="nccl", timeout=self.ddp_timeout_delta)
             device = torch.device("cuda", self.local_rank)
             self._n_gpu = 1

src/transformers/training_args.py

@@ -18,6 +18,7 @@ import math
 import os
 import warnings
 from dataclasses import asdict, dataclass, field
+from datetime import timedelta
 from enum import Enum
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Union
@@ -481,6 +482,11 @@ class TrainingArguments:
             are also available. See the [Ray documentation](
             https://docs.ray.io/en/latest/tune/api_docs/analysis.html#ray.tune.ExperimentAnalysis.get_best_trial) for
             more options.
+        ddp_timeout (`int`, *optional*, defaults to 1800):
+            The timeout for `torch.distributed.init_process_group` calls, used to avoid GPU socket timeouts when
+            performing slow operations in distributed runs. Please refer to the [PyTorch documentation](
+            https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group) for more
+            information.
         use_mps_device (`bool`, *optional*, defaults to `False`):
             Whether to use Apple Silicon chip based `mps` device.
     """
@@ -971,6 +977,12 @@ class TrainingArguments:
             )
         },
     )
+    ddp_timeout: Optional[int] = field(
+        default=1800,
+        metadata={
+            "help": "Overrides the default timeout for distributed training (value should be given in seconds)."
+        },
+    )

     def __post_init__(self):
         # Handle --use_env option in torch.distributed.launch (local_rank not passed as an arg then).
@@ -1291,6 +1303,13 @@ class TrainingArguments:
         eval_batch_size = per_device_batch_size * max(1, self.n_gpu)
         return eval_batch_size

+    @property
+    def ddp_timeout_delta(self) -> timedelta:
+        """
+        The actual timeout for torch.distributed.init_process_group since it expects a timedelta variable.
+        """
+        return timedelta(seconds=self.ddp_timeout)
+
     @cached_property
     @torch_required
     def _setup_devices(self) -> "torch.device":
@@ -1358,7 +1377,9 @@ class TrainingArguments:
                         f"num_cpu_threads_per_process unset, we set it at {num_cpu_threads_per_process} to improve oob"
                         " performance."
                     )
-                torch.distributed.init_process_group(backend=self.xpu_backend, rank=rank, world_size=size)
+                torch.distributed.init_process_group(
+                    backend=self.xpu_backend, rank=rank, world_size=size, timeout=self.ddp_timeout_delta
+                )
         elif is_torch_tpu_available():
             device = xm.xla_device()
             self._n_gpu = 0
@@ -1369,7 +1390,7 @@ class TrainingArguments:
         elif is_sagemaker_dp_enabled():
             import smdistributed.dataparallel.torch.torch_smddp  # noqa: F401

-            dist.init_process_group(backend="smddp")
+            dist.init_process_group(backend="smddp", timeout=self.ddp_timeout_delta)
             self.local_rank = int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))
             device = torch.device("cuda", self.local_rank)
             self._n_gpu = 1
@@ -1431,7 +1452,7 @@ class TrainingArguments:
             # Here, we'll use torch.distributed.
             # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
             if not torch.distributed.is_initialized():
-                torch.distributed.init_process_group(backend="nccl")
+                torch.distributed.init_process_group(backend="nccl", timeout=self.ddp_timeout_delta)
             device = torch.device("cuda", self.local_rank)
             self._n_gpu = 1
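
A standalone sketch of the underlying call this commit parameterizes (an illustration, not code from the commit; it assumes the script is launched with torchrun so the env:// rendezvous variables are set, and uses the "gloo" backend so it also runs on CPU-only hosts):

from datetime import timedelta

import torch.distributed as dist

if __name__ == "__main__":
    # torch.distributed.init_process_group expects a timedelta, which is why the
    # user-facing integer `ddp_timeout` (seconds) is converted via `ddp_timeout_delta`.
    dist.init_process_group(backend="gloo", timeout=timedelta(seconds=7200))
    dist.barrier()
    dist.destroy_process_group()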