mirror of https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
refactor deepspeed setup devices (#9880)
This commit is contained in:
parent 6bf94bc0b6
commit 1420b5ff67
@@ -535,6 +535,20 @@ class TrainingArguments:
             self.local_rank = dist.get_local_rank()
             device = torch.device("cuda", self.local_rank)
             self._n_gpu = 1
+        elif self.deepspeed:
+            # deepspeed performs its own DDP internally, and requires the program to be started with:
+            # deepspeed ./program.py
+            # rather than:
+            # python -m torch.distributed.launch --nproc_per_node=2 ./program.py
+            from .integrations import is_deepspeed_available
+
+            if not is_deepspeed_available():
+                raise ImportError("--deepspeed requires deepspeed: `pip install deepspeed`.")
+            import deepspeed
+
+            deepspeed.init_distributed()
+            device = torch.device("cuda", self.local_rank)
+            self._n_gpu = 1
         elif self.local_rank == -1:
             # if n_gpu is > 1 we'll use nn.DataParallel.
             # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0`
@@ -549,21 +563,7 @@ class TrainingArguments:
         else:
             # Here, we'll use torch.distributed.
             # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
             #
-            # deepspeed performs its own DDP internally, and requires the program to be started with:
-            # deepspeed ./program.py
-            # rather than:
-            # python -m torch.distributed.launch --nproc_per_node=2 ./program.py
-            if self.deepspeed:
-                from .integrations import is_deepspeed_available
-
-                if not is_deepspeed_available():
-                    raise ImportError("--deepspeed requires deepspeed: `pip install deepspeed`.")
-                import deepspeed
-
-                deepspeed.init_distributed()
-            else:
-                torch.distributed.init_process_group(backend="nccl")
+            torch.distributed.init_process_group(backend="nccl")
             device = torch.device("cuda", self.local_rank)
             self._n_gpu = 1
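The comment carried into the new `elif self.deepspeed:` branch is the crux of the change: when the program is started through the `deepspeed` launcher, `deepspeed.init_distributed()` sets up `torch.distributed` itself, so the Trainer no longer calls `init_process_group` on that path. A minimal standalone sketch of that launch mode, not taken from this commit (script name, GPU count, and the env-variable handling are illustrative):

# Run as:    deepspeed --num_gpus=2 ./program.py
# not as:    python -m torch.distributed.launch --nproc_per_node=2 ./program.py
import os

import torch
import deepspeed

# The deepspeed launcher exports LOCAL_RANK for every worker it spawns.
local_rank = int(os.environ.get("LOCAL_RANK", 0))

# Initializes torch.distributed (NCCL by default) internally, which is why the
# refactored branch above skips torch.distributed.init_process_group().
deepspeed.init_distributed()

torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)
print(f"rank {torch.distributed.get_rank()} of {torch.distributed.get_world_size()} on {device}")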
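The new branch also guards on `is_deepspeed_available()` imported from `.integrations`, so a missing package surfaces as a clear ImportError rather than a bare ModuleNotFoundError. The actual helper lives in transformers' integrations module; a plausible minimal sketch of such a probe, assuming the usual importlib pattern (not copied from the repo):

import importlib.util

def is_deepspeed_available() -> bool:
    # True only when the `deepspeed` package can be found on the import path,
    # letting the Trainer fail early with a descriptive installation hint.
    return importlib.util.find_spec("deepspeed") is not None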