fixes issue when saving fsdp via accelerate's FSDP plugin (#24446)

Sourab Mangrulkar 2023-06-23 18:03:57 +05:30 committed by GitHub
parent 2898fd3968
commit a6f37f8879

@@ -2322,7 +2322,7 @@ class Trainer:
                 torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
         elif self.args.should_save and not self.is_deepspeed_enabled:
             # deepspeed.save_checkpoint above saves model/optim/sched
-            if self.fsdp:
+            if self.fsdp and not self.is_fsdp_enabled:
                 torch.save(full_osd, os.path.join(output_dir, OPTIMIZER_NAME))
             else:
                 torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
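
For context, here is a minimal sketch (a hypothetical standalone helper, not the actual Trainer method) of the optimizer-saving branch after this change: the consolidated full optimizer state dict (full_osd) is only written when FSDP was set up through the Trainer's own fsdp argument, while runs that enable FSDP through accelerate's FSDP plugin (is_fsdp_enabled) fall through to saving self.optimizer.state_dict() directly.

import os

import torch

# Assumption: mirrors the OPTIMIZER_NAME constant used in transformers' trainer.py.
OPTIMIZER_NAME = "optimizer.pt"


def save_optimizer_state(trainer, output_dir, full_osd=None):
    """Hypothetical helper illustrating the guarded save logic from this commit."""
    if trainer.fsdp and not trainer.is_fsdp_enabled:
        # Legacy Trainer-managed FSDP: save the consolidated (full) optimizer
        # state dict that was gathered earlier on all ranks.
        torch.save(full_osd, os.path.join(output_dir, OPTIMIZER_NAME))
    else:
        # Non-FSDP runs, or FSDP driven by accelerate's FSDP plugin: save the
        # optimizer's own state_dict as-is.
        torch.save(trainer.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))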