Mirror of https://github.com/huggingface/transformers.git
Support saving only PEFT adapter in checkpoints when using PEFT + FSDP (#28297)
* Update trainer.py
* Revert "Update trainer.py". This reverts commit 0557e2cc9effa3a41304322032239a3874b948a7.
* Make trainer.py use adapter_only=True when using FSDP + PEFT
* Support load_best_model with adapter_only=True
* Ruff format
* Inspect function args for the save_/load_ FSDP utility functions and only pass adapter_only=True if they support it
parent da3c79b245
commit a055d09e11
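The heart of the change is a capability probe: before forwarding the new adapter_only keyword to accelerate's FSDP checkpoint utilities, the trainer inspects their signatures, so older accelerate releases keep working unchanged. Below is a minimal standalone sketch of the same pattern; the stand-in save functions and the maybe_adapter_only helper are hypothetical, not part of this commit:

import inspect


def save_ckpt_old(model, output_dir):
    # Stand-in for an accelerate release that predates adapter_only.
    print(f"full checkpoint -> {output_dir}")


def save_ckpt_new(model, output_dir, adapter_only=False):
    # Stand-in for a release that supports adapter-only saving.
    kind = "adapter-only" if adapter_only else "full"
    print(f"{kind} checkpoint -> {output_dir}")


def maybe_adapter_only(fn):
    # Same probe as _get_fsdp_ckpt_kwargs below: pass the kwarg only
    # when the target function's signature declares it.
    if "adapter_only" in inspect.signature(fn).parameters:
        return {"adapter_only": True}
    return {}


for save_fn in (save_ckpt_old, save_ckpt_new):
    save_fn("model", "ckpt", **maybe_adapter_only(save_fn))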
@@ -223,6 +223,14 @@ def _is_peft_model(model):
     return False


+def _get_fsdp_ckpt_kwargs():
+    # TODO: @AjayP13, @younesbelkada replace this check with version check at the next `accelerate` release
+    if is_accelerate_available() and "adapter_only" in list(inspect.signature(save_fsdp_model).parameters):
+        return {"adapter_only": True}
+    else:
+        return {}
+
+
 if TYPE_CHECKING:
     import optuna

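Because the helper returns a dict that is splatted into each call, the change degrades gracefully: with a recent accelerate the utilities receive adapter_only=True, while on older releases the empty dict adds nothing. A quick illustration of that no-op property, using a hypothetical function f:

def f(path, adapter_only=False):
    return (path, adapter_only)


assert f("ckpt", **{}) == ("ckpt", False)                     # empty dict: call unchanged
assert f("ckpt", **{"adapter_only": True}) == ("ckpt", True)  # kwarg injected when supported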
@@ -2129,7 +2137,13 @@ class Trainer:
                 # release memory
                 del state_dict
             elif self.is_fsdp_enabled:
-                load_fsdp_model(self.accelerator.state.fsdp_plugin, self.accelerator, model, resume_from_checkpoint)
+                load_fsdp_model(
+                    self.accelerator.state.fsdp_plugin,
+                    self.accelerator,
+                    model,
+                    resume_from_checkpoint,
+                    **_get_fsdp_ckpt_kwargs(),
+                )
             else:
                 # We load the model state dict on the CPU to avoid an OOM error.
                 if self.args.save_safetensors and os.path.isfile(safe_weights_file):
@@ -2182,7 +2196,11 @@ class Trainer:
             deepspeed_load_checkpoint(self.model_wrapped, self.state.best_model_checkpoint)
         elif self.is_fsdp_enabled:
             load_result = load_fsdp_model(
-                self.accelerator.state.fsdp_plugin, self.accelerator, model, self.state.best_model_checkpoint
+                self.accelerator.state.fsdp_plugin,
+                self.accelerator,
+                model,
+                self.state.best_model_checkpoint,
+                **_get_fsdp_ckpt_kwargs(),
             )
         elif (
             os.path.exists(best_model_path)
@@ -2504,7 +2522,9 @@ class Trainer:
             self.model_wrapped.save_checkpoint(output_dir)
         elif self.is_fsdp_enabled:
             # save fsdp specific ckpt for resuming from ckpt
-            save_fsdp_model(self.accelerator.state.fsdp_plugin, self.accelerator, self.model, output_dir)
+            save_fsdp_model(
+                self.accelerator.state.fsdp_plugin, self.accelerator, self.model, output_dir, **_get_fsdp_ckpt_kwargs()
+            )
             save_fsdp_optimizer(
                 self.accelerator.state.fsdp_plugin, self.accelerator, self.optimizer, self.model, output_dir
             )
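Note that the probe inspects save_fsdp_model once and reuses the answer for the other utilities, which assumes they gained adapter_only in the same accelerate release; the save_fsdp_optimizer call above is left without the extra kwargs, while the optimizer-restore call in the final hunk does receive them. A hedged sketch of a per-function variant of the probe, not part of this commit:

import inspect


def supports_kwarg(fn, name):
    # True if fn declares the named parameter explicitly or accepts **kwargs.
    params = inspect.signature(fn).parameters
    return name in params or any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
    )


def g(path, **kwargs):
    return path


def h(path, adapter_only=False):
    return path


print(supports_kwarg(h, "adapter_only"))  # True, declared explicitly
print(supports_kwarg(g, "adapter_only"))  # True, via **kwargs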
@@ -2598,6 +2618,7 @@ class Trainer:
                     self.optimizer,
                     self.model,
                     checkpoint,
+                    **_get_fsdp_ckpt_kwargs(),
                 )
             else:
                 self.optimizer.load_state_dict(
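In practice these branches fire when a PEFT-wrapped model is trained under FSDP, so intermediate checkpoints store only the adapter weights instead of the full sharded state dict. A hedged end-to-end sketch; the base model and LoRA settings are illustrative, not taken from this commit:

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments

model = AutoModelForCausalLM.from_pretrained("gpt2")  # illustrative base model
model = get_peft_model(model, LoraConfig(r=8, lora_alpha=16, task_type="CAUSAL_LM"))

args = TrainingArguments(
    output_dir="out",
    fsdp="full_shard auto_wrap",  # turns on the is_fsdp_enabled branches above
    save_steps=500,
)

# Launch under torchrun/accelerate launch so FSDP is active; with a recent
# accelerate, each checkpoint then contains only the PEFT adapter.
trainer = Trainer(model=model, args=args)  # train_dataset omitted for brevity
# trainer.train()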