mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
Index RNG states by global rank in saves (#17852)
This commit is contained in:
parent
7cf52a49de
commit
7c1b91281f
@ -1922,12 +1922,12 @@ class Trainer:
|
||||
if checkpoint is None:
|
||||
return
|
||||
|
||||
local_rank = xm.get_local_ordinal() if is_torch_tpu_available() else self.args.local_rank
|
||||
if local_rank != -1:
|
||||
rng_file = os.path.join(checkpoint, f"rng_state_{local_rank}.pth")
|
||||
if self.args.world_size > 1:
|
||||
process_index = self.args.process_index
|
||||
rng_file = os.path.join(checkpoint, f"rng_state_{process_index}.pth")
|
||||
if not os.path.isfile(rng_file):
|
||||
logger.info(
|
||||
f"Didn't find an RNG file for process {local_rank}, if you are resuming a training that "
|
||||
f"Didn't find an RNG file for process {process_index}, if you are resuming a training that "
|
||||
"wasn't launched in a distributed fashion, reproducibility is not guaranteed."
|
||||
)
|
||||
return
|
||||
@ -2067,11 +2067,10 @@ class Trainer:
|
||||
# not yet exist.
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
local_rank = xm.get_local_ordinal() if is_torch_tpu_available() else self.args.local_rank
|
||||
if local_rank == -1:
|
||||
if self.args.world_size <= 1:
|
||||
torch.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
|
||||
else:
|
||||
torch.save(rng_states, os.path.join(output_dir, f"rng_state_{local_rank}.pth"))
|
||||
torch.save(rng_states, os.path.join(output_dir, f"rng_state_{self.args.process_index}.pth"))
|
||||
|
||||
if self.args.push_to_hub:
|
||||
self._push_from_checkpoint(output_dir)
|
||||
|
Loading…
Reference in New Issue
Block a user