Fix the key name for _load_rng_state under torch.cuda (#36138)

fix load key name for _load_rng_state under torch.cuda

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
Wizyoung 2025-02-14 00:35:08 +08:00 committed by GitHub
parent bfe46c98b5
commit 12962fe84b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -3125,7 +3125,7 @@ class Trainer:
is_distributed = self.args.parallel_mode == ParallelMode.DISTRIBUTED
if torch.cuda.is_available():
set_rng_state_for_device("GPU", torch.cuda, checkpoint_rng_state, is_distributed)
set_rng_state_for_device("CUDA", torch.cuda, checkpoint_rng_state, is_distributed)
if is_torch_npu_available():
set_rng_state_for_device("NPU", torch.npu, checkpoint_rng_state, is_distributed)
if is_torch_mlu_available():