mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 18:22:34 +06:00
Accumulate opt state dict on dp_rank 0 (#11481)
This commit is contained in:
parent
1e8e06862f
commit
f4c9a7e62e
@ -1420,14 +1420,15 @@ class Trainer:
|
|||||||
xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
||||||
reissue_pt_warnings(caught_warnings)
|
reissue_pt_warnings(caught_warnings)
|
||||||
elif is_sagemaker_mp_enabled():
|
elif is_sagemaker_mp_enabled():
|
||||||
# Consolidate the state dict on all processes of dp_rank 0
|
if smp.dp_rank() == 0:
|
||||||
opt_state_dict = self.optimizer.state_dict()
|
# Consolidate the state dict on all processes of dp_rank 0
|
||||||
# Save it and the scheduler on the main process
|
opt_state_dict = self.optimizer.state_dict()
|
||||||
if self.is_world_process_zero():
|
# Save it and the scheduler on the main process
|
||||||
torch.save(opt_state_dict, os.path.join(output_dir, "optimizer.pt"))
|
if self.is_world_process_zero():
|
||||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
torch.save(opt_state_dict, os.path.join(output_dir, "optimizer.pt"))
|
||||||
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||||
reissue_pt_warnings(caught_warnings)
|
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
||||||
|
reissue_pt_warnings(caught_warnings)
|
||||||
elif self.is_world_process_zero() and not self.deepspeed:
|
elif self.is_world_process_zero() and not self.deepspeed:
|
||||||
# deepspeed.save_checkpoint above saves model/optim/sched
|
# deepspeed.save_checkpoint above saves model/optim/sched
|
||||||
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
||||||
|
Loading…
Reference in New Issue
Block a user