mirror of https://github.com/huggingface/transformers.git
parent 06bab00338
commit a8aad0ec93
@@ -1537,9 +1537,7 @@ class TrainingArguments:
             )
         if self.no_cuda:
             self.distributed_state = PartialState(cpu=True)
-            device = self.distributed_state.device
             self._n_gpu = 0
-            self.local_rank = self.distributed_state.local_process_index
         elif is_sagemaker_mp_enabled():
             local_rank = smp.local_rank()
             device = torch.device("cuda", local_rank)
@@ -1548,11 +1546,12 @@ class TrainingArguments:
         elif self.deepspeed:
             self.distributed_state = PartialState(timeout=timedelta(seconds=self.ddp_timeout))
             self._n_gpu = 1
-            device = self.distributed_state.device
         else:
             self.distributed_state = PartialState(backend=self.xpu_backend)
-            device = self.distributed_state.device
             self._n_gpu = 1
+        if not is_sagemaker_mp_enabled():
+            device = self.distributed_state.device
+            self.local_rank = self.distributed_state.local_process_index
         if (
             torch.distributed.is_available()
             and torch.distributed.is_initialized()