Fix non-deterministic Megatron-LM checkpoint name (#24674)

Fix non-deterministic checkpoint name

`os.listdir` returns directory entries in arbitrary, filesystem-dependent
order, so indexing the first listed file as the code does
(`os.listdir(...)[0]`) is non-deterministic.

As a result, the code can pick up a checkpoint file such as
`distrib_optim.pt`, which does not contain the desired information — for
example, the arguments originally passed to Megatron-LM when the
checkpoint was saved.
This commit is contained in:
janEbert 2023-07-11 20:55:04 +02:00 committed by GitHub
parent 33aafc26ee
commit aac4c79968
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -291,8 +291,10 @@ def get_megatron_sharded_states(args, tp_size, pp_size, pp_rank):
tp_state_dicts = []
for i in range(tp_size):
sub_dir_name = f"mp_rank_{i:02d}" if pp_size == 1 else f"mp_rank_{i:02d}_{pp_rank:03d}"
checkpoint_name = os.listdir(os.path.join(args.load_path, sub_dir_name))[0]
checkpoint_path = os.path.join(args.load_path, sub_dir_name, checkpoint_name)
for checkpoint_name in ["model_optim_rng.pt", "model_rng.pt"]:
checkpoint_path = os.path.join(args.load_path, sub_dir_name, checkpoint_name)
if os.path.isfile(checkpoint_path):
break
state_dict = torch.load(checkpoint_path, map_location="cpu")
tp_state_dicts.append(state_dict)
return tp_state_dicts