diff --git a/src/transformers/integrations/peft.py b/src/transformers/integrations/peft.py
index 90ddd68c269..7148da3dd12 100644
--- a/src/transformers/integrations/peft.py
+++ b/src/transformers/integrations/peft.py
@@ -14,6 +14,7 @@
 
 import importlib
 import inspect
+import re
 import warnings
 from typing import Any, Dict, List, Optional, Union
 
@@ -149,6 +150,7 @@ class PeftAdapterMixin:
 
         # peft only supports low_cpu_mem_usage starting from v0.13.0
         peft_load_kwargs = {}
+        key_mapping = adapter_kwargs.pop("key_mapping", None) if adapter_kwargs is not None else None
         if low_cpu_mem_usage:
             min_version_lcmu = "0.13.0"
             if version.parse(importlib.metadata.version("peft")) >= version.parse(min_version_lcmu):
@@ -233,6 +235,13 @@ class PeftAdapterMixin:
                 new_key = key[len(prefix) :]
             else:
                 new_key = key
+
+            if key_mapping:
+                for pattern, replacement in key_mapping.items():
+                    new_key, n_replace = re.subn(pattern, replacement, new_key)
+                    # Early exit of the loop
+                    if n_replace > 0:
+                        break
             processed_adapter_state_dict[new_key] = value
 
         # Load state dict
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index e1668ffd7cc..9e30fbb43a1 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -4778,6 +4778,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
             model.hf_quantizer = hf_quantizer
 
         if _adapter_model_path is not None:
+            adapter_kwargs["key_mapping"] = key_mapping
            model.load_adapter(
                _adapter_model_path,
                adapter_name=adapter_name,