Mirror of https://github.com/huggingface/transformers.git
Support save/load ckpt for XLA FSDP (#32311)

* Support save/load ckpt for XLA FSDP
* Fix bug for save
* Fix style
* Preserve sharded ckpt and better file naming
* Minor fix
* Add is_fsdp_xla_v1_enabled

Co-authored-by: Zach Mueller <muellerzr@gmail.com>
Parent: f1b720ed62
Commit: 8a4857c0db
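Before the diff, a quick sketch of the file-naming scheme this commit introduces: under XLA FSDP v1, every rank writes its own shard, with the rank and world size encoded in the file name. The values below are illustrative stand-ins; in the Trainer they come from self.args.process_index and self.args.world_size, and OPTIMIZER_NAME is transformers' constant ("optimizer.pt").

    import os

    # Illustrative stand-ins, not Trainer internals (assumptions):
    process_index, world_size = 0, 4
    OPTIMIZER_NAME = "optimizer.pt"   # transformers' optimizer file name
    output_dir = "checkpoint-500"     # hypothetical checkpoint directory

    # The per-rank shard path used throughout this commit:
    shard_path = os.path.join(output_dir, f"rank{process_index}-of-{world_size}-{OPTIMIZER_NAME}")
    print(shard_path)  # checkpoint-500/rank0-of-4-optimizer.pt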
@@ -702,6 +702,7 @@ class Trainer:
             # Tensor axis is just a placeholder where it will not be used in FSDPv2.
             num_devices = xr.global_runtime_device_count()
             xs.set_global_mesh(xs.Mesh(np.array(range(num_devices)), (num_devices, 1), axis_names=("fsdp", "tensor")))
+        self.is_fsdp_xla_v1_enabled = self.is_fsdp_xla_enabled and not self.is_fsdp_xla_v2_enabled

     def _activate_neftune(self, model):
         r"""
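The new attribute just names the existing v1 case: XLA FSDP is enabled but the SPMD-based v2 path is not (v1 wraps modules in torch_xla's XlaFullyShardedDataParallel). A minimal sketch of the flag logic with illustrative values:

    # v1 <=> XLA FSDP on, SPMD-based FSDPv2 off (illustrative values)
    is_fsdp_xla_enabled = True
    is_fsdp_xla_v2_enabled = False
    is_fsdp_xla_v1_enabled = is_fsdp_xla_enabled and not is_fsdp_xla_v2_enabled
    print(is_fsdp_xla_v1_enabled)  # True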
@@ -3002,7 +3003,20 @@ class Trainer:
     def _save_optimizer_and_scheduler(self, output_dir):
         if is_torch_xla_available():
             xm.rendezvous("saving_optimizer_states")
-            xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
+            if self.is_fsdp_xla_v1_enabled:
+                optm = {
+                    "optimizer": self.optimizer.state_dict(),
+                    "shard_metadata": self.model.get_shard_metadata(),
+                }
+                xm.save(
+                    optm,
+                    os.path.join(
+                        output_dir, f"rank{self.args.process_index}-of-{self.args.world_size}-{OPTIMIZER_NAME}"
+                    ),
+                    master_only=False,
+                )
+            else:
+                xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
             with warnings.catch_warnings(record=True) as caught_warnings:
                 xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
             reissue_pt_warnings(caught_warnings)
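Two details worth noting here: xm.save(..., master_only=False) makes every process write its own file (by default only the master ordinal saves), and each shard carries the FSDP wrapper's shard_metadata so the shards can be consolidated later. A quick way to sanity-check one shard on CPU (path is illustrative):

    import torch

    # Hypothetical shard written by rank 0 of a 4-process run.
    shard = torch.load("checkpoint-500/rank0-of-4-optimizer.pt", map_location="cpu")
    print(sorted(shard))                  # ['optimizer', 'shard_metadata']
    optimizer_state = shard["optimizer"]  # a regular optimizer state_dict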
@@ -3080,11 +3094,26 @@ class Trainer:
                 )
             )
         )
+        checkpoint_file_exists = (
+            glob.glob(os.path.join(checkpoint, f"rank*-of-{self.args.world_size}-{OPTIMIZER_NAME}"))
+            if self.is_fsdp_xla_v1_enabled
+            else checkpoint_file_exists
+        )
         if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
             # Load in optimizer and scheduler states
             if is_torch_xla_available():
                 # On TPU we have to take some extra precautions to properly load the states on the right device.
-                optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu")
+                if self.is_fsdp_xla_v1_enabled:
+                    optimizer_state = torch.load(
+                        os.path.join(
+                            checkpoint, f"rank{self.args.process_index}-of-{self.args.world_size}-{OPTIMIZER_NAME}"
+                        ),
+                        map_location="cpu",
+                    )
+                    # We only need `optimizer` when resuming from checkpoint
+                    optimizer_state = optimizer_state["optimizer"]
+                else:
+                    optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu")
                 with warnings.catch_warnings(record=True) as caught_warnings:
                     lr_scheduler_state = torch.load(os.path.join(checkpoint, SCHEDULER_NAME), map_location="cpu")
                 reissue_pt_warnings(caught_warnings)
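On resume, the existence check switches to a glob over the per-rank files. Note that the pattern encodes self.args.world_size, so a checkpoint written with a different number of processes will not match. A standalone sketch of the check (directory and world size are illustrative):

    import glob
    import os

    checkpoint = "checkpoint-500"   # hypothetical resume directory
    world_size = 4                  # must match the run that wrote the shards
    OPTIMIZER_NAME = "optimizer.pt"

    shards = glob.glob(os.path.join(checkpoint, f"rank*-of-{world_size}-{OPTIMIZER_NAME}"))
    print(bool(shards))             # True only if matching sharded states exist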
@@ -3499,7 +3528,7 @@ class Trainer:
         model = self.model
         xm.mark_step()

-        if xm.is_master_ordinal():
+        if xm.is_master_ordinal(local=False):
             os.makedirs(output_dir, exist_ok=True)
             torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

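xm.is_master_ordinal() defaults to local=True, which is true for the master process of each host; local=False is true only for the single global master. On a multi-host TPU pod this means the output directory and training args are written exactly once rather than once per host (assuming a shared filesystem):

    import torch_xla.core.xla_model as xm

    # local=True (default): master of this host; local=False: global master.
    if xm.is_master_ordinal(local=False):
        print("global master: safe place for job-wide, write-once side effects")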
@@ -3507,7 +3536,40 @@ class Trainer:
         # They can then be reloaded using `from_pretrained()`
         supported_classes = (PushToHubMixin,)
         xm.rendezvous("saving_checkpoint")
-        if not isinstance(model, supported_classes):
+        if self.is_fsdp_xla_v1_enabled:
+            ckpt = {
+                "model": model.state_dict(),
+                "shard_metadata": model.get_shard_metadata(),
+            }
+            ckpt_path = os.path.join(
+                output_dir, f"rank{self.args.process_index}-of-{self.args.world_size}-{WEIGHTS_NAME}"
+            )
+            # All ranks save sharded checkpoint
+            xm.save(ckpt, ckpt_path, master_only=False)
+            # Make sure all ranks have saved checkpoints
+            xm.rendezvous("save_full_checkpoints")
+            # Master save full checkpoint
+            if self.args.should_save:
+                from torch_xla.distributed.fsdp import consolidate_sharded_model_checkpoints
+
+                full_state_dict, _ = consolidate_sharded_model_checkpoints(
+                    ckpt_prefix=os.path.join(output_dir, ""),
+                    ckpt_suffix=f"rank*-of-*-{WEIGHTS_NAME}",
+                    save_model=False,
+                )
+                model = model.module.module
+                unwrapped_model = self.accelerator.unwrap_model(model)
+                if isinstance(unwrapped_model, supported_classes):
+                    unwrapped_model.save_pretrained(
+                        output_dir,
+                        state_dict=full_state_dict,
+                        save_function=xm.save,
+                        safe_serialization=self.args.save_safetensors,
+                    )
+                else:
+                    logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
+                    xm.save(full_state_dict, os.path.join(output_dir, WEIGHTS_NAME))
+        elif not isinstance(model, supported_classes):
             if isinstance(self.accelerator.unwrap_model(model), supported_classes):
                 self.accelerator.unwrap_model(model).save_pretrained(
                     output_dir,
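consolidate_sharded_model_checkpoints is torch_xla's own helper, and since the sharded checkpoints are preserved on disk, the same call also works offline: if a run only left the rank*-of-* shards behind, the full state dict can be rebuilt after the fact. A minimal sketch, mirroring the call in the diff (directory name is illustrative; WEIGHTS_NAME in transformers is "pytorch_model.bin"):

    import torch
    from torch_xla.distributed.fsdp import consolidate_sharded_model_checkpoints

    # Stitch rank*-of-*-pytorch_model.bin shards back into one full state dict.
    full_state_dict, _ = consolidate_sharded_model_checkpoints(
        ckpt_prefix="checkpoint-500/",               # directory holding the shards
        ckpt_suffix="rank*-of-*-pytorch_model.bin",  # the Trainer's naming pattern
        save_model=False,                            # return the dict, don't write a file
    )
    torch.save(full_state_dict, "checkpoint-500/pytorch_model.bin")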