Adds timeout argument to training_args to avoid socket timeouts in DDP (#18562)
* chore(training_args): Adds support for timeout argument.
* fix(training_args): Passes make style through changes.
* fix(training_args): Removes wrong docstring sentence.
* fix(training_args): Fixes timeout not being JSON serializable.
* fix(training_args_sm): Also updates timeout to timeout_delta.
* fix(training_args): Fixes PR according to suggestions.
parent ab663b2274
commit fe58929ad6
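For context, a minimal usage sketch of the argument this commit introduces (not part of the diff; the output_dir value is illustrative, and the timeout is given in seconds as documented below):

    from transformers import TrainingArguments

    # Raise the DDP timeout from the default 1800 s (30 min) to 2 hours for jobs
    # whose preprocessing or checkpointing keeps ranks idle between collectives.
    args = TrainingArguments(
        output_dir="out",  # illustrative path
        ddp_timeout=7200,
    )

    # The new ddp_timeout_delta property converts the integer seconds into the
    # timedelta that torch.distributed.init_process_group expects.
    print(args.ddp_timeout_delta)  # 2:00:00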
src/transformers/sagemaker/training_args_sm.py
@@ -92,7 +92,7 @@ class SageMakerTrainingArguments(TrainingArguments):
         elif is_sagemaker_dp_enabled():
             import smdistributed.dataparallel.torch.torch_smddp  # noqa: F401

-            torch.distributed.init_process_group(backend="smddp")
+            torch.distributed.init_process_group(backend="smddp", timeout=self.ddp_timeout_delta)
             self.local_rank = int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))
             device = torch.device("cuda", self.local_rank)
             self._n_gpu = 1
@@ -111,7 +111,7 @@ class SageMakerTrainingArguments(TrainingArguments):
             # Here, we'll use torch.distributed.
             # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
             if not torch.distributed.is_initialized():
-                torch.distributed.init_process_group(backend="nccl")
+                torch.distributed.init_process_group(backend="nccl", timeout=self.ddp_timeout_delta)
             device = torch.device("cuda", self.local_rank)
             self._n_gpu = 1
src/transformers/training_args.py
@@ -18,6 +18,7 @@ import math
 import os
 import warnings
 from dataclasses import asdict, dataclass, field
+from datetime import timedelta
 from enum import Enum
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Union
@@ -481,6 +482,11 @@ class TrainingArguments:
             are also available. See the [Ray documentation](
             https://docs.ray.io/en/latest/tune/api_docs/analysis.html#ray.tune.ExperimentAnalysis.get_best_trial) for
             more options.
+        ddp_timeout (`int`, *optional*, defaults to 1800):
+            The timeout for `torch.distributed.init_process_group` calls, used to avoid GPU socket timeouts when
+            performing slow operations in distributed runs. Please refer to the [PyTorch documentation]
+            (https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group) for more
+            information.
         use_mps_device (`bool`, *optional*, defaults to `False`):
             Whether to use Apple Silicon chip based `mps` device.
     """
@@ -971,6 +977,12 @@ class TrainingArguments:
             )
         },
     )
+    ddp_timeout: Optional[int] = field(
+        default=1800,
+        metadata={
+            "help": "Overrides the default timeout for distributed training (value should be given in seconds)."
+        },
+    )

     def __post_init__(self):
         # Handle --use_env option in torch.distributed.launch (local_rank not passed as an arg then).
@@ -1291,6 +1303,13 @@ class TrainingArguments:
         eval_batch_size = per_device_batch_size * max(1, self.n_gpu)
         return eval_batch_size

+    @property
+    def ddp_timeout_delta(self) -> timedelta:
+        """
+        The actual timeout for torch.distributed.init_process_group since it expects a timedelta variable.
+        """
+        return timedelta(seconds=self.ddp_timeout)
+
     @cached_property
     @torch_required
     def _setup_devices(self) -> "torch.device":
@@ -1358,7 +1377,9 @@ class TrainingArguments:
                     f"num_cpu_threads_per_process unset, we set it at {num_cpu_threads_per_process} to improve oob"
                     " performance."
                 )
-            torch.distributed.init_process_group(backend=self.xpu_backend, rank=rank, world_size=size)
+            torch.distributed.init_process_group(
+                backend=self.xpu_backend, rank=rank, world_size=size, timeout=self.ddp_timeout_delta
+            )
         elif is_torch_tpu_available():
             device = xm.xla_device()
             self._n_gpu = 0
@@ -1369,7 +1390,7 @@ class TrainingArguments:
         elif is_sagemaker_dp_enabled():
             import smdistributed.dataparallel.torch.torch_smddp  # noqa: F401

-            dist.init_process_group(backend="smddp")
+            dist.init_process_group(backend="smddp", timeout=self.ddp_timeout_delta)
             self.local_rank = int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))
             device = torch.device("cuda", self.local_rank)
             self._n_gpu = 1
@@ -1431,7 +1452,7 @@ class TrainingArguments:
             # Here, we'll use torch.distributed.
             # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
             if not torch.distributed.is_initialized():
-                torch.distributed.init_process_group(backend="nccl")
+                torch.distributed.init_process_group(backend="nccl", timeout=self.ddp_timeout_delta)
             device = torch.device("cuda", self.local_rank)
             self._n_gpu = 1
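For reference, a self-contained sketch of what the change amounts to at process-group initialization (independent of TrainingArguments; the DDP_TIMEOUT environment variable is purely illustrative, and the init_process_group call assumes the usual launcher variables such as RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT are already set):

    import os
    from datetime import timedelta

    import torch.distributed as dist

    # The user-facing value is an int in seconds; init_process_group expects a
    # datetime.timedelta, which is what the new ddp_timeout_delta property builds.
    ddp_timeout = int(os.environ.get("DDP_TIMEOUT", "1800"))  # illustrative source
    timeout_delta = timedelta(seconds=ddp_timeout)

    # Passing the timedelta raises the per-collective timeout above the 30-minute
    # default, which is where the reported DDP socket timeouts came from.
    dist.init_process_group(backend="nccl", timeout=timeout_delta)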