fix resume fsdp (#23111)
* fix resume fsdp
* fix rank 0 loading
* fix style and quality
This commit is contained in:
parent 3b74889e8f
commit adb0760b5f
```diff
@@ -2114,7 +2114,7 @@ class Trainer:
         safe_weights_index_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_INDEX_NAME)

         if not any(
-            [os.path.isfile(f) for f in [weights_file, safe_weights_file, weights_index_file, safe_weights_index_file]]
+            os.path.isfile(f) for f in [weights_file, safe_weights_file, weights_index_file, safe_weights_index_file]
         ):
             raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")
```
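The only functional change in this hunk is dropping the list brackets so `any` consumes a generator expression. A minimal sketch of why that matters, using made-up file names rather than the Trainer's checkpoint constants:

```python
import os

# Hypothetical candidate files, not the repository's constants.
candidates = ["pytorch_model.bin", "model.safetensors"]

# List comprehension: every path is checked before `any` sees a single result.
found_eager = any([os.path.isfile(f) for f in candidates])

# Generator expression: `any` short-circuits on the first existing file,
# so later paths are never touched.
found_lazy = any(os.path.isfile(f) for f in candidates)

assert found_eager == found_lazy  # same truth value, less eager work
```

Behaviour is identical; the generator simply lets `any` stop at the first hit instead of building the whole list up front.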
```diff
@@ -2364,6 +2364,12 @@ class Trainer:
         if self.sharded_ddp == ShardedDDPOption.SIMPLE:
             self.optimizer.consolidate_state_dict()

+        if self.fsdp:
+            # FSDP has a different interface for saving optimizer states.
+            # Needs to be called on all ranks to gather all states.
+            # full_optim_state_dict will be deprecated after Pytorch 2.2!
+            full_osd = self.model.__class__.full_optim_state_dict(self.model, self.optimizer)
+
         if is_torch_tpu_available():
             xm.rendezvous("saving_optimizer_states")
             xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
```
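For readers following along outside the Trainer, here is a minimal, hedged sketch of the gather step this hunk adds, assuming a model already wrapped in PyTorch's FullyShardedDataParallel; the helper name is illustrative and not part of the patch:

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def gather_full_optimizer_state(model: FSDP, optimizer: torch.optim.Optimizer):
    """Collect a single, unsharded copy of the optimizer state on rank 0.

    This must run on every rank: full_optim_state_dict is a collective, so
    calling it on only one process would leave the others waiting.
    """
    # With the default rank0_only=True, the complete dict is returned on
    # rank 0 and an empty dict on the other ranks.
    full_osd = FSDP.full_optim_state_dict(model, optimizer)
    return full_osd
```

In the Trainer the call is written as `self.model.__class__.full_optim_state_dict(...)`, which resolves to this same FSDP static method because `self.model` is the FSDP-wrapped module.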
```diff
@@ -2388,7 +2394,11 @@ class Trainer:
                 torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
         elif self.args.should_save and not self.deepspeed:
             # deepspeed.save_checkpoint above saves model/optim/sched
-            torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
+            if self.fsdp:
+                torch.save(full_osd, os.path.join(output_dir, OPTIMIZER_NAME))
+            else:
+                torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
+
             with warnings.catch_warnings(record=True) as caught_warnings:
                 torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
             reissue_pt_warnings(caught_warnings)
```
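A hedged sketch of the save logic this hunk introduces, reduced to a standalone helper: under FSDP the consolidated `full_osd` gathered above is written instead of the local (sharded) optimizer state, and only the saving process touches disk. The helper name, the `is_main_process` flag, and the hard-coded file name are assumptions for illustration:

```python
import os
import torch

OPTIMIZER_NAME = "optimizer.pt"  # assumed to mirror the Trainer's constant of the same name


def save_optimizer_state(checkpoint_dir, optimizer, full_osd=None, is_main_process=True):
    """Write the optimizer state once, from the main process only."""
    if not is_main_process:
        return
    os.makedirs(checkpoint_dir, exist_ok=True)
    if full_osd is not None:
        # FSDP path: persist the full state dict gathered on rank 0.
        torch.save(full_osd, os.path.join(checkpoint_dir, OPTIMIZER_NAME))
    else:
        # Non-FSDP path: the optimizer's own state dict is already complete.
        torch.save(optimizer.state_dict(), os.path.join(checkpoint_dir, OPTIMIZER_NAME))
```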
```diff
@@ -2498,9 +2508,18 @@ class Trainer:
                 # In distributed training however, we load directly on each GPU and risk the GPU OOM as it's more
                 # likely to get OOM on CPU (since we load num_gpu times the optimizer state
                 map_location = self.args.device if self.args.world_size > 1 else "cpu"
-                self.optimizer.load_state_dict(
-                    torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
-                )
+                if self.fsdp:
+                    full_osd = None
+                    # In FSDP, we need to load the full optimizer state dict on rank 0 and then shard it
+                    if self.args.process_index == 0:
+                        full_osd = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME))
+                    # call scatter_full_optim_state_dict on all ranks
+                    sharded_osd = self.model.__class__.scatter_full_optim_state_dict(full_osd, self.model)
+                    self.optimizer.load_state_dict(sharded_osd)
+                else:
+                    self.optimizer.load_state_dict(
+                        torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
+                    )
             with warnings.catch_warnings(record=True) as caught_warnings:
                 self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
             reissue_pt_warnings(caught_warnings)
```
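Finally, a hedged sketch of the resume path added here, again as a standalone helper: rank 0 loads the full optimizer state from disk, and every rank then calls `scatter_full_optim_state_dict` to receive its shard. The function name and the explicit `rank` argument are illustrative; only the two FSDP calls mirror the patch:

```python
import os
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

OPTIMIZER_NAME = "optimizer.pt"  # assumed to match the name used at save time


def load_sharded_optimizer_state(checkpoint_dir, model: FSDP, optimizer, rank: int):
    """Restore FSDP optimizer state that was saved as one full state dict."""
    full_osd = None
    # Only rank 0 reads the (potentially large) full state dict from disk.
    if rank == 0:
        full_osd = torch.load(os.path.join(checkpoint_dir, OPTIMIZER_NAME))
    # scatter_full_optim_state_dict is a collective: every rank must call it,
    # and each rank receives only its own shard of the state.
    sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, model)
    optimizer.load_state_dict(sharded_osd)
```

Note that the scatter call must run on all ranks even though only rank 0 holds the full state; skipping it on non-zero ranks would hang the job.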