diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 9d42674f209..2ff4f93fdfc 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -2740,13 +2740,17 @@ class Trainer: self._save(output_dir, state_dict=state_dict) elif self.is_deepspeed_enabled: # this takes care of everything as long as we aren't under zero3 - if self.args.should_save: - self._save(output_dir) + if self.args.should_save and not is_deepspeed_zero3_enabled(): + if version.parse(accelerate_version) <= version.parse("0.20.3"): + raise ValueError("Install Accelerate from main branch") + state_dict = self.accelerator.get_state_dict(self.deepspeed) + + self._save(output_dir, state_dict=state_dict) if is_deepspeed_zero3_enabled(): # It's too complicated to try to override different places where the weights dump gets # saved, so since under zero3 the file is bogus, simply delete it. The user should - # either user deepspeed checkpoint to resume or to recover full weights use + # either use deepspeed checkpoint to resume or to recover full weights use # zero_to_fp32.py stored in the checkpoint. if self.args.should_save: file = os.path.join(output_dir, WEIGHTS_NAME)