Fixup multigpu local_rank (#22869)

Fixup multigpu tests
Zachary Mueller 2023-04-19 14:37:16 -04:00 committed by GitHub
parent 06bab00338
commit a8aad0ec93

src/transformers/training_args.py

@@ -1537,9 +1537,7 @@ class TrainingArguments:
             )
         if self.no_cuda:
             self.distributed_state = PartialState(cpu=True)
-            device = self.distributed_state.device
             self._n_gpu = 0
-            self.local_rank = self.distributed_state.local_process_index
         elif is_sagemaker_mp_enabled():
             local_rank = smp.local_rank()
             device = torch.device("cuda", local_rank)
@@ -1548,11 +1546,12 @@
         elif self.deepspeed:
             self.distributed_state = PartialState(timeout=timedelta(seconds=self.ddp_timeout))
             self._n_gpu = 1
-            device = self.distributed_state.device
         else:
             self.distributed_state = PartialState(backend=self.xpu_backend)
-            device = self.distributed_state.device
             self._n_gpu = 1
+        if not is_sagemaker_mp_enabled():
+            device = self.distributed_state.device
+            self.local_rank = self.distributed_state.local_process_index
         if (
             torch.distributed.is_available()
             and torch.distributed.is_initialized()
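
For reference, a minimal sketch of the pattern the diff converges on, assuming `accelerate` is installed (`PartialState` and its `device` / `local_process_index` attributes are part of accelerate's public API): each non-SageMaker-MP branch constructs a `PartialState`, and `device` and `local_rank` are then read from it once in a shared block rather than per-branch.

# Illustrative only, not the Trainer source: shows how PartialState
# centralizes the device/rank bookkeeping the way this commit does.
from accelerate import PartialState

# Mirrors the `no_cuda` branch above; the other branches pass different kwargs.
state = PartialState(cpu=True)

# The commit moves these two reads out of the individual branches into one
# shared block guarded by `if not is_sagemaker_mp_enabled():`.
device = state.device                    # torch.device for this process
local_rank = state.local_process_index   # index of this process on its node

print(f"device={device}, local_rank={local_rank}")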