Support saving only the PEFT adapter in checkpoints when using PEFT + FSDP (#28297)

* Update trainer.py

* Revert "Update trainer.py"

This reverts commit 0557e2cc9effa3a41304322032239a3874b948a7.

* Make trainer.py use adapter_only=True when using FSDP + PEFT

* Support load_best_model with adapter_only=True

* Ruff format

* Inspect function args for the save_/load_ FSDP utility functions and only pass adapter_only=True if they support it
Author: Ajay Patel, 2024-01-29 12:10:15 -05:00 (committed by GitHub)
Parent: da3c79b245
Commit: a055d09e11
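For context, a rough usage sketch of the scenario this change targets: a LoRA adapter trained through the Trainer with FSDP enabled, where the FSDP checkpoints written during training should now contain only the adapter weights whenever the installed `accelerate` build accepts `adapter_only`. The model name, dataset, and hyperparameters below are illustrative placeholders, not part of this commit.

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments

# Wrap a base model with a LoRA adapter (placeholder model name and hyperparameters).
base_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
model = get_peft_model(base_model, LoraConfig(r=8, lora_alpha=16, task_type="CAUSAL_LM"))

# Enable FSDP through TrainingArguments; with this change, the FSDP-specific
# checkpoints saved under output_dir should hold only the adapter, provided the
# installed `accelerate` exposes `adapter_only` on its save/load FSDP utilities.
args = TrainingArguments(
    output_dir="out",
    fsdp="full_shard auto_wrap",
    save_strategy="steps",
    save_steps=500,
)
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)  # `train_dataset` is a placeholder
trainer.train()
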
@@ -223,6 +223,14 @@ def _is_peft_model(model):
     return False


+def _get_fsdp_ckpt_kwargs():
+    # TODO: @AjayP13, @younesbelkada replace this check with version check at the next `accelerate` release
+    if is_accelerate_available() and "adapter_only" in list(inspect.signature(save_fsdp_model).parameters):
+        return {"adapter_only": True}
+    else:
+        return {}
+
+
 if TYPE_CHECKING:
     import optuna
@@ -2129,7 +2137,13 @@ class Trainer:
                 # release memory
                 del state_dict
         elif self.is_fsdp_enabled:
-            load_fsdp_model(self.accelerator.state.fsdp_plugin, self.accelerator, model, resume_from_checkpoint)
+            load_fsdp_model(
+                self.accelerator.state.fsdp_plugin,
+                self.accelerator,
+                model,
+                resume_from_checkpoint,
+                **_get_fsdp_ckpt_kwargs(),
+            )
         else:
             # We load the model state dict on the CPU to avoid an OOM error.
             if self.args.save_safetensors and os.path.isfile(safe_weights_file):
@@ -2182,7 +2196,11 @@ class Trainer:
             deepspeed_load_checkpoint(self.model_wrapped, self.state.best_model_checkpoint)
         elif self.is_fsdp_enabled:
             load_result = load_fsdp_model(
-                self.accelerator.state.fsdp_plugin, self.accelerator, model, self.state.best_model_checkpoint
+                self.accelerator.state.fsdp_plugin,
+                self.accelerator,
+                model,
+                self.state.best_model_checkpoint,
+                **_get_fsdp_ckpt_kwargs(),
             )
         elif (
             os.path.exists(best_model_path)
@@ -2504,7 +2522,9 @@ class Trainer:
             self.model_wrapped.save_checkpoint(output_dir)
         elif self.is_fsdp_enabled:
             # save fsdp specific ckpt for resuming from ckpt
-            save_fsdp_model(self.accelerator.state.fsdp_plugin, self.accelerator, self.model, output_dir)
+            save_fsdp_model(
+                self.accelerator.state.fsdp_plugin, self.accelerator, self.model, output_dir, **_get_fsdp_ckpt_kwargs()
+            )
             save_fsdp_optimizer(
                 self.accelerator.state.fsdp_plugin, self.accelerator, self.optimizer, self.model, output_dir
             )
@@ -2598,6 +2618,7 @@ class Trainer:
                         self.optimizer,
                         self.model,
                         checkpoint,
+                        **_get_fsdp_ckpt_kwargs(),
                     )
                 else:
                     self.optimizer.load_state_dict(
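
The `_get_fsdp_ckpt_kwargs` helper added above gates the new keyword on a signature check so older `accelerate` releases keep working until the check can become a plain version comparison (per the TODO). The same compatibility pattern in isolation, as a minimal standalone sketch; the save_model_* functions here are illustrative stand-ins, not accelerate APIs:

import inspect


def call_with_supported_kwargs(fn, *args, **optional_kwargs):
    # Forward each optional kwarg only if the callee's signature declares it,
    # so a newer feature degrades gracefully on older library versions.
    accepted = inspect.signature(fn).parameters
    supported = {k: v for k, v in optional_kwargs.items() if k in accepted}
    return fn(*args, **supported)


def save_model_v1(state, path):  # older API: no adapter_only parameter
    print(f"v1: saved full state to {path}")


def save_model_v2(state, path, adapter_only=False):  # newer API
    print(f"v2: saved to {path} with adapter_only={adapter_only}")


# Both calls succeed; only save_model_v2 actually receives adapter_only=True.
call_with_supported_kwargs(save_model_v1, {}, "ckpt/", adapter_only=True)
call_with_supported_kwargs(save_model_v2, {}, "ckpt/", adapter_only=True)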