Index RNG states by global rank in saves (#17852)

This commit is contained in:
Sylvain Gugger 2022-06-23 12:53:50 -04:00 committed by GitHub
parent 7cf52a49de
commit 7c1b91281f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1922,12 +1922,12 @@ class Trainer:
if checkpoint is None:
return
local_rank = xm.get_local_ordinal() if is_torch_tpu_available() else self.args.local_rank
if local_rank != -1:
rng_file = os.path.join(checkpoint, f"rng_state_{local_rank}.pth")
if self.args.world_size > 1:
process_index = self.args.process_index
rng_file = os.path.join(checkpoint, f"rng_state_{process_index}.pth")
if not os.path.isfile(rng_file):
logger.info(
f"Didn't find an RNG file for process {local_rank}, if you are resuming a training that "
f"Didn't find an RNG file for process {process_index}, if you are resuming a training that "
"wasn't launched in a distributed fashion, reproducibility is not guaranteed."
)
return
@ -2067,11 +2067,10 @@ class Trainer:
# not yet exist.
os.makedirs(output_dir, exist_ok=True)
local_rank = xm.get_local_ordinal() if is_torch_tpu_available() else self.args.local_rank
if local_rank == -1:
if self.args.world_size <= 1:
torch.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
else:
torch.save(rng_states, os.path.join(output_dir, f"rng_state_{local_rank}.pth"))
torch.save(rng_states, os.path.join(output_dir, f"rng_state_{self.args.process_index}.pth"))
if self.args.push_to_hub:
self._push_from_checkpoint(output_dir)