Move import check to before state reset (#23906)

* Move import check to before state reset

* Guard better
This commit is contained in:
Zachary Mueller 2023-05-31 10:49:43 -04:00 committed by GitHub
parent e42869b091
commit 84bac652f3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1667,12 +1667,12 @@ class TrainingArguments:
def _setup_devices(self) -> "torch.device":
requires_backends(self, ["torch"])
logger.info("PyTorch: setting up devices")
AcceleratorState._reset_state()
PartialState._reset_state()
if not is_sagemaker_mp_enabled() and not is_accelerate_available(check_partial_state=True):
raise ImportError(
"Using the `Trainer` with `PyTorch` requires `accelerate>=0.19.0`: Please run `pip install transformers[torch]` or `pip install accelerate -U`"
)
if not is_sagemaker_mp_enabled():
if not is_accelerate_available(check_partial_state=True):
raise ImportError(
"Using the `Trainer` with `PyTorch` requires `accelerate>=0.19.0`: Please run `pip install transformers[torch]` or `pip install accelerate -U`"
)
AcceleratorState._reset_state(reset_partial_state=True)
self.distributed_state = None
if self.no_cuda:
self.distributed_state = PartialState(cpu=True, backend=self.ddp_backend)