fix ds z3 checkpointing when stage3_gather_16bit_weights_on_model_save=False (#25817)

* fix ds z3 checkpointing when  `stage3_gather_16bit_weights_on_model_save=False`

* refactoring
This commit is contained in:
Sourab Mangrulkar 2023-08-31 15:17:53 +05:30 committed by GitHub
parent f8468b4fac
commit e95bcaeef0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 14 additions and 6 deletions

View File

@ -93,6 +93,7 @@ from .trainer_pt_utils import (
nested_numpify,
nested_xla_mesh_reduce,
reissue_pt_warnings,
remove_dummy_checkpoint,
)
from .trainer_utils import (
PREFIX_CHECKPOINT_DIR,
@ -2780,12 +2781,8 @@ class Trainer:
if self.args.should_save:
self._save(output_dir, state_dict=state_dict)
if self.is_fsdp_enabled:
# remove the dummy state_dict saved above
if self.args.should_save:
for filename in [WEIGHTS_NAME, SAFE_WEIGHTS_NAME]:
file = os.path.join(output_dir, filename)
if os.path.isfile(file):
os.remove(file)
# remove the dummy state_dict
remove_dummy_checkpoint(self.args.should_save, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
save_fsdp_model(self.accelerator.state.fsdp_plugin, self.accelerator, self.model, output_dir)
elif self.is_deepspeed_enabled:
@ -2801,6 +2798,9 @@ class Trainer:
" stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use"
" zero_to_fp32.py to recover weights"
)
self._save(output_dir, state_dict={})
# remove the dummy state_dict
remove_dummy_checkpoint(self.args.should_save, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
self.model_wrapped.save_checkpoint(output_dir)
elif self.args.should_save:

View File

@ -1089,6 +1089,14 @@ def get_module_class_from_name(module, name):
return module_class
def remove_dummy_checkpoint(is_main_process, output_dir, filenames):
if is_main_process:
for filename in filenames:
file = os.path.join(output_dir, filename)
if os.path.isfile(file):
os.remove(file)
if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp