Mirror of https://github.com/huggingface/transformers.git
Fix non-deterministic Megatron-LM checkpoint name (#24674)
Fix non-deterministic checkpoint name

`os.listdir`'s order is not deterministic, which is a problem when querying the first listed file, as the code did with `os.listdir(...)[0]`. This can return a checkpoint name such as `distrib_optim.pt`, which does not include desired information such as the saved arguments originally given to Megatron-LM.
This commit is contained in:
parent 33aafc26ee
commit aac4c79968
@@ -291,8 +291,10 @@ def get_megatron_sharded_states(args, tp_size, pp_size, pp_rank):
     tp_state_dicts = []
     for i in range(tp_size):
         sub_dir_name = f"mp_rank_{i:02d}" if pp_size == 1 else f"mp_rank_{i:02d}_{pp_rank:03d}"
-        checkpoint_name = os.listdir(os.path.join(args.load_path, sub_dir_name))[0]
-        checkpoint_path = os.path.join(args.load_path, sub_dir_name, checkpoint_name)
+        for checkpoint_name in ["model_optim_rng.pt", "model_rng.pt"]:
+            checkpoint_path = os.path.join(args.load_path, sub_dir_name, checkpoint_name)
+            if os.path.isfile(checkpoint_path):
+                break
         state_dict = torch.load(checkpoint_path, map_location="cpu")
         tp_state_dicts.append(state_dict)
     return tp_state_dicts
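For reference, a minimal self-contained sketch of the deterministic lookup this commit introduces. The helper name `load_tp_state_dicts` and its flat parameters are illustrative stand-ins for the converter's `get_megatron_sharded_states(args, tp_size, pp_size, pp_rank)`, which reads the checkpoint directory from `args.load_path`:

import os

import torch


def load_tp_state_dicts(load_path, tp_size, pp_size, pp_rank):
    """Load one Megatron-LM state dict per tensor-parallel rank for a given pipeline rank."""
    tp_state_dicts = []
    for i in range(tp_size):
        sub_dir_name = f"mp_rank_{i:02d}" if pp_size == 1 else f"mp_rank_{i:02d}_{pp_rank:03d}"
        # Probe the known Megatron-LM checkpoint file names instead of taking
        # whatever os.listdir happens to return first (e.g. distrib_optim.pt).
        for checkpoint_name in ["model_optim_rng.pt", "model_rng.pt"]:
            checkpoint_path = os.path.join(load_path, sub_dir_name, checkpoint_name)
            if os.path.isfile(checkpoint_path):
                break
        # As in the patched code, loading fails here if neither file exists.
        state_dict = torch.load(checkpoint_path, map_location="cpu")
        tp_state_dicts.append(state_dict)
    return tp_state_dicts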