mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix the key name for _load_rng_state under torch.cuda (#36138)
fix load key name for _load_rng_state under torch.cuda Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
parent
bfe46c98b5
commit
12962fe84b
@ -3125,7 +3125,7 @@ class Trainer:
|
||||
|
||||
is_distributed = self.args.parallel_mode == ParallelMode.DISTRIBUTED
|
||||
if torch.cuda.is_available():
|
||||
set_rng_state_for_device("GPU", torch.cuda, checkpoint_rng_state, is_distributed)
|
||||
set_rng_state_for_device("CUDA", torch.cuda, checkpoint_rng_state, is_distributed)
|
||||
if is_torch_npu_available():
|
||||
set_rng_state_for_device("NPU", torch.npu, checkpoint_rng_state, is_distributed)
|
||||
if is_torch_mlu_available():
|
||||
|
Loading…
Reference in New Issue
Block a user