CUDA rng_state_all is used when saving in distributed mode, so the same should also be used when loading (#23045)

The CUDA RNG state should be restored with set_rng_state_all in distributed mode, because the states of all devices were saved.
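
For context, a minimal sketch of the getter/setter pairing involved (the "cuda" key mirrors the checkpoint dict in the diff below; the rest is illustrative, not the Trainer's exact code):

    import torch

    # In distributed mode the Trainer saves the RNG state of *every* visible
    # GPU, so the checkpoint holds a list of state tensors under "cuda".
    checkpoint_rng_state = {"cuda": torch.cuda.random.get_rng_state_all()}

    # Restoring therefore needs the matching _all setter; plain set_rng_state
    # expects a single tensor and cannot consume the saved list.
    torch.cuda.random.set_rng_state_all(checkpoint_rng_state["cuda"])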
Shivam Shrirao 2023-04-28 18:58:01 +05:30 committed by GitHub
parent 521a8ffa53
commit 4d0ea3d269


@@ -2327,10 +2327,10 @@ class Trainer:
         torch.random.set_rng_state(checkpoint_rng_state["cpu"])
         if torch.cuda.is_available():
             if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
-                torch.cuda.random.set_rng_state(checkpoint_rng_state["cuda"])
+                torch.cuda.random.set_rng_state_all(checkpoint_rng_state["cuda"])
             else:
                 try:
-                    torch.cuda.random.set_rng_state_all(checkpoint_rng_state["cuda"])
+                    torch.cuda.random.set_rng_state(checkpoint_rng_state["cuda"])
                 except Exception as e:
                     logger.info(
                         f"Didn't manage to set back the RNG states of the GPU because of the following error:\n {e}"