Mirror of https://github.com/huggingface/transformers.git
Move import check to before state reset (#23906)
* Move import check to before state reset

* Guard better
parent e42869b091
commit 84bac652f3
src/transformers/training_args.py
@@ -1667,12 +1667,12 @@ class TrainingArguments:
     def _setup_devices(self) -> "torch.device":
         requires_backends(self, ["torch"])
         logger.info("PyTorch: setting up devices")
-        AcceleratorState._reset_state()
-        PartialState._reset_state()
-        if not is_sagemaker_mp_enabled() and not is_accelerate_available(check_partial_state=True):
-            raise ImportError(
-                "Using the `Trainer` with `PyTorch` requires `accelerate>=0.19.0`: Please run `pip install transformers[torch]` or `pip install accelerate -U`"
-            )
+        if not is_sagemaker_mp_enabled():
+            if not is_accelerate_available(check_partial_state=True):
+                raise ImportError(
+                    "Using the `Trainer` with `PyTorch` requires `accelerate>=0.19.0`: Please run `pip install transformers[torch]` or `pip install accelerate -U`"
+                )
+            AcceleratorState._reset_state(reset_partial_state=True)
         self.distributed_state = None
         if self.no_cuda:
             self.distributed_state = PartialState(cpu=True, backend=self.ddp_backend)
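Why the order matters: before this commit, `AcceleratorState._reset_state()` and `PartialState._reset_state()` ran before the accelerate availability check, so an environment without a recent `accelerate` could fail on the state reset itself instead of reaching the actionable ImportError; the combined condition also did not prevent the reset under SageMaker MP. Below is a minimal, self-contained sketch of the check-before-mutate pattern the patch adopts. The helper and class here are hypothetical stand-ins for the real transformers/accelerate ones, not their actual implementations:

# Hypothetical stand-ins for the real helpers/classes; illustration only.
ACCELERATE_INSTALLED = False  # simulate a machine without `accelerate`

def is_accelerate_available(check_partial_state=True):
    return ACCELERATE_INSTALLED

class AcceleratorState:
    @classmethod
    def _reset_state(cls, reset_partial_state=False):
        if not ACCELERATE_INSTALLED:
            # Stand-in for the opaque failure the old ordering could hit:
            # accelerate state is touched before the dependency is verified.
            raise RuntimeError("accelerate state touched without accelerate")

def setup_old():
    AcceleratorState._reset_state()    # mutate global state first ...
    if not is_accelerate_available():  # ... friendly check comes too late
        raise ImportError("Using the `Trainer` with `PyTorch` requires `accelerate>=0.19.0`")

def setup_new():
    if not is_accelerate_available():  # check first, fail with a clear message
        raise ImportError("Using the `Trainer` with `PyTorch` requires `accelerate>=0.19.0`")
    AcceleratorState._reset_state(reset_partial_state=True)  # safe to reset now

for fn in (setup_old, setup_new):
    try:
        fn()
    except Exception as exc:
        print(f"{fn.__name__}: {type(exc).__name__}: {exc}")
# setup_old: RuntimeError: accelerate state touched without accelerate
# setup_new: ImportError: Using the `Trainer` with `PyTorch` requires `accelerate>=0.19.0`

The "Guard better" half of the commit is the outer `is_sagemaker_mp_enabled()` check in the diff above: when SageMaker model parallelism manages devices, both the accelerate check and the state reset are now skipped entirely.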