deepspeed z1/z2 state dict fix (#24489)

* deepspeed z2/z1 state_dict bloating fix

* update

* version check
Sourab Mangrulkar 2023-06-26 17:45:37 +05:30 committed by GitHub
parent c8aff1d3e6
commit 195a9e5bdb


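The diff below replaces the plain `self._save(output_dir)` call under ZeRO-1/2 with a consolidated state dict fetched through Accelerate, so the saved file no longer carries the bloat the commit message mentions. As a rough standalone sketch of that save path (not the actual Trainer internals; `save_zero12_checkpoint`, `engine`, and `output_dir` are illustrative names, and the 0.20.3 floor mirrors the version check in the diff):

```python
import os

import torch
from packaging import version

import accelerate
from accelerate import Accelerator


def save_zero12_checkpoint(accelerator: Accelerator, engine, output_dir: str) -> None:
    """Save a DeepSpeed ZeRO-1/2 model as a single, non-bloated weights file."""
    if version.parse(accelerate.__version__) <= version.parse("0.20.3"):
        raise ValueError("Install Accelerate from main branch")
    # get_state_dict() returns the unwrapped model's consolidated state_dict,
    # so the file written below contains only model weights.
    state_dict = accelerator.get_state_dict(engine)
    if accelerator.is_main_process:
        os.makedirs(output_dir, exist_ok=True)
        torch.save(state_dict, os.path.join(output_dir, "pytorch_model.bin"))
```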
@@ -2740,13 +2740,17 @@ class Trainer:
                 self._save(output_dir, state_dict=state_dict)
         elif self.is_deepspeed_enabled:
             # this takes care of everything as long as we aren't under zero3
-            if self.args.should_save:
-                self._save(output_dir)
+            if self.args.should_save and not is_deepspeed_zero3_enabled():
+                if version.parse(accelerate_version) <= version.parse("0.20.3"):
+                    raise ValueError("Install Accelerate from main branch")
+                state_dict = self.accelerator.get_state_dict(self.deepspeed)
+                self._save(output_dir, state_dict=state_dict)
 
             if is_deepspeed_zero3_enabled():
                 # It's too complicated to try to override different places where the weights dump gets
                 # saved, so since under zero3 the file is bogus, simply delete it. The user should
-                # either user deepspeed checkpoint to resume or to recover full weights use
+                # either use deepspeed checkpoint to resume or to recover full weights use
                 # zero_to_fp32.py stored in the checkpoint.
                 if self.args.should_save:
                     file = os.path.join(output_dir, WEIGHTS_NAME)
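
For the ZeRO-3 branch, the comment above points users at the `zero_to_fp32.py` script that DeepSpeed stores inside each checkpoint folder. A minimal sketch of that recovery path using DeepSpeed's equivalent Python helper (the checkpoint path is a hypothetical example):

```python
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

# Hypothetical path to a Trainer/DeepSpeed checkpoint folder containing the ZeRO shards.
checkpoint_dir = "output/checkpoint-500"

# Merges the sharded ZeRO partitions back into a full fp32 state_dict on CPU.
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir)
# model.load_state_dict(state_dict)  # load into a freshly instantiated model, then save normally
```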