diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index fc2ba427867..4ab1829859e 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -1537,9 +1537,7 @@ class TrainingArguments:
             )
         if self.no_cuda:
             self.distributed_state = PartialState(cpu=True)
-            device = self.distributed_state.device
             self._n_gpu = 0
-            self.local_rank = self.distributed_state.local_process_index
         elif is_sagemaker_mp_enabled():
             local_rank = smp.local_rank()
             device = torch.device("cuda", local_rank)
@@ -1548,11 +1546,12 @@ class TrainingArguments:
         elif self.deepspeed:
             self.distributed_state = PartialState(timeout=timedelta(seconds=self.ddp_timeout))
             self._n_gpu = 1
-            device = self.distributed_state.device
         else:
             self.distributed_state = PartialState(backend=self.xpu_backend)
-            device = self.distributed_state.device
             self._n_gpu = 1
+        if not is_sagemaker_mp_enabled():
+            device = self.distributed_state.device
+            self.local_rank = self.distributed_state.local_process_index
         if (
             torch.distributed.is_available()
             and torch.distributed.is_initialized()
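
For context, a minimal standalone sketch (assumed function name `setup_devices`, not the actual transformers source) of the control flow after this patch: every branch that creates an accelerate `PartialState` now has `device` and `local_rank` read from it once, after the branching, instead of repeating the two assignments inside each branch. The SageMaker MP branch, which computes its own device, is omitted here; it is the reason for the `if not is_sagemaker_mp_enabled()` guard in the real patch.

from datetime import timedelta

from accelerate import PartialState


def setup_devices(no_cuda: bool = False, use_deepspeed: bool = False, ddp_timeout: int = 1800):
    """Simplified mirror of the branching in TrainingArguments after this diff."""
    if no_cuda:
        distributed_state = PartialState(cpu=True)
        n_gpu = 0
    elif use_deepspeed:
        # Extra kwargs such as `timeout` are forwarded to the process-group setup.
        distributed_state = PartialState(timeout=timedelta(seconds=ddp_timeout))
        n_gpu = 1
    else:
        distributed_state = PartialState()
        n_gpu = 1

    # Consolidated assignments: before the patch, each branch above set these
    # two values itself; now they come from the shared PartialState in one place.
    device = distributed_state.device
    local_rank = distributed_state.local_process_index
    return device, local_rank, n_gpu


if __name__ == "__main__":
    # In a single, non-distributed process this prints a CPU device and rank 0.
    print(setup_devices(no_cuda=True))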