mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
deepspeed z1/z2 state dict fix (#24489)
* deepspeed z2/z1 state_dict bloating fix * update * version check
This commit is contained in:
parent
c8aff1d3e6
commit
195a9e5bdb
@ -2740,13 +2740,17 @@ class Trainer:
|
||||
self._save(output_dir, state_dict=state_dict)
|
||||
elif self.is_deepspeed_enabled:
|
||||
# this takes care of everything as long as we aren't under zero3
|
||||
if self.args.should_save:
|
||||
self._save(output_dir)
|
||||
if self.args.should_save and not is_deepspeed_zero3_enabled():
|
||||
if version.parse(accelerate_version) <= version.parse("0.20.3"):
|
||||
raise ValueError("Install Accelerate from main branch")
|
||||
state_dict = self.accelerator.get_state_dict(self.deepspeed)
|
||||
|
||||
self._save(output_dir, state_dict=state_dict)
|
||||
|
||||
if is_deepspeed_zero3_enabled():
|
||||
# It's too complicated to try to override different places where the weights dump gets
|
||||
# saved, so since under zero3 the file is bogus, simply delete it. The user should
|
||||
# either user deepspeed checkpoint to resume or to recover full weights use
|
||||
# either use deepspeed checkpoint to resume or to recover full weights use
|
||||
# zero_to_fp32.py stored in the checkpoint.
|
||||
if self.args.should_save:
|
||||
file = os.path.join(output_dir, WEIGHTS_NAME)
|
||||
|
Loading…
Reference in New Issue
Block a user