Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-30 17:52:35 +06:00)
Parent: 06bab00338
Commit: a8aad0ec93
@@ -1537,9 +1537,7 @@ class TrainingArguments:
        )
        if self.no_cuda:
            self.distributed_state = PartialState(cpu=True)
            device = self.distributed_state.device
            self._n_gpu = 0
            self.local_rank = self.distributed_state.local_process_index
        elif is_sagemaker_mp_enabled():
            local_rank = smp.local_rank()
            device = torch.device("cuda", local_rank)
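For context, PartialState here comes from the accelerate library. A minimal sketch (my illustration, not part of the commit) of what the no_cuda branch above resolves to in a single-process run:

```python
# Illustrative sketch, not part of the diff: accelerate's PartialState is what
# the no_cuda branch delegates device bookkeeping to.
from accelerate import PartialState

state = PartialState(cpu=True)           # cpu=True forces CPU even if CUDA is visible
device = state.device                    # torch.device("cpu")
local_rank = state.local_process_index   # 0 when not under a distributed launcher
n_gpu = 0                                # mirrors self._n_gpu = 0 in the diff
print(device, local_rank, n_gpu)
```

PartialState shares its state across instances within a process, which is why the diff constructs it once and caches it on self.distributed_state.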
@@ -1548,11 +1546,12 @@ class TrainingArguments:
        elif self.deepspeed:
            self.distributed_state = PartialState(timeout=timedelta(seconds=self.ddp_timeout))
            self._n_gpu = 1
            device = self.distributed_state.device
        else:
            self.distributed_state = PartialState(backend=self.xpu_backend)
            device = self.distributed_state.device
            self._n_gpu = 1
        if not is_sagemaker_mp_enabled():
            device = self.distributed_state.device
            self.local_rank = self.distributed_state.local_process_index
        if (
            torch.distributed.is_available()
            and torch.distributed.is_initialized()
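The deepspeed and else branches pass their keyword arguments through to process-group initialization. A hedged sketch of the same calls outside the Trainer (the 1800-second timeout is a made-up value; the real code reads it from self.ddp_timeout):

```python
# Illustrative sketch, not part of the commit. Under a distributed launcher
# these kwargs reach torch.distributed.init_process_group; in a plain
# single-process run they have no visible effect.
from datetime import timedelta

from accelerate import PartialState

# deepspeed branch: widen the collective-op timeout (1800s is illustrative;
# the diff uses timedelta(seconds=self.ddp_timeout)).
state = PartialState(timeout=timedelta(seconds=1800))

# else branch: pin an explicit backend instead (self.xpu_backend in the diff),
# e.g. PartialState(backend="ccl") for an Intel XPU setup.

print(state.device, state.local_process_index)
```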