mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
fix ds z3 checkpointing when stage3_gather_16bit_weights_on_model_save=False
(#25817)
* fix DeepSpeed ZeRO-3 checkpointing when `stage3_gather_16bit_weights_on_model_save=False`
* refactoring
This commit is contained in:
parent
f8468b4fac
commit
e95bcaeef0
@ -93,6 +93,7 @@ from .trainer_pt_utils import (
|
||||
nested_numpify,
|
||||
nested_xla_mesh_reduce,
|
||||
reissue_pt_warnings,
|
||||
remove_dummy_checkpoint,
|
||||
)
|
||||
from .trainer_utils import (
|
||||
PREFIX_CHECKPOINT_DIR,
|
||||
@ -2780,12 +2781,8 @@ class Trainer:
|
||||
if self.args.should_save:
|
||||
self._save(output_dir, state_dict=state_dict)
|
||||
if self.is_fsdp_enabled:
|
||||
# remove the dummy state_dict saved above
|
||||
if self.args.should_save:
|
||||
for filename in [WEIGHTS_NAME, SAFE_WEIGHTS_NAME]:
|
||||
file = os.path.join(output_dir, filename)
|
||||
if os.path.isfile(file):
|
||||
os.remove(file)
|
||||
# remove the dummy state_dict
|
||||
remove_dummy_checkpoint(self.args.should_save, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
|
||||
save_fsdp_model(self.accelerator.state.fsdp_plugin, self.accelerator, self.model, output_dir)
|
||||
|
||||
elif self.is_deepspeed_enabled:
|
||||
@ -2801,6 +2798,9 @@ class Trainer:
|
||||
" stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use"
|
||||
" zero_to_fp32.py to recover weights"
|
||||
)
|
||||
self._save(output_dir, state_dict={})
|
||||
# remove the dummy state_dict
|
||||
remove_dummy_checkpoint(self.args.should_save, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
|
||||
self.model_wrapped.save_checkpoint(output_dir)
|
||||
|
||||
elif self.args.should_save:
|
||||
|
@ -1089,6 +1089,14 @@ def get_module_class_from_name(module, name):
|
||||
return module_class
|
||||
|
||||
|
||||
def remove_dummy_checkpoint(is_main_process, output_dir, filenames):
    """Delete placeholder checkpoint files from `output_dir`, on the main process only.

    Args:
        is_main_process (`bool`):
            Whether the caller is the main (rank-0) process; non-main processes do nothing.
        output_dir (`str`):
            Directory in which the dummy checkpoint files were written.
        filenames (iterable of `str`):
            File names (relative to `output_dir`) to remove. Names that do not
            exist as regular files are silently skipped.
    """
    # Only the main process wrote the dummy files, so only it cleans them up.
    if not is_main_process:
        return
    for name in filenames:
        candidate = os.path.join(output_dir, name)
        # Guard with isfile so missing entries (and directories) are left alone.
        if os.path.isfile(candidate):
            os.remove(candidate)
|
||||
|
||||
|
||||
if is_sagemaker_mp_enabled():
|
||||
import smdistributed.modelparallel.torch as smp
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user