bugfix: propagate weight key_mapping to peft to fix 3.52 VLM renaming (#38627)

* propagate key mapping to peft

* propagate key mapping to peft

* make requested changes

* revert
This commit is contained in:
Manuel Faysse 2025-06-16 10:10:23 +02:00 committed by GitHub
parent 925da8ac56
commit ce6ac53ac1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 10 additions and 0 deletions

View File

@ -14,6 +14,7 @@
import importlib import importlib
import inspect import inspect
import re
import warnings import warnings
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
@ -149,6 +150,7 @@ class PeftAdapterMixin:
# peft only supports low_cpu_mem_usage starting from v0.13.0 # peft only supports low_cpu_mem_usage starting from v0.13.0
peft_load_kwargs = {} peft_load_kwargs = {}
key_mapping = adapter_kwargs.pop("key_mapping", None) if adapter_kwargs is not None else None
if low_cpu_mem_usage: if low_cpu_mem_usage:
min_version_lcmu = "0.13.0" min_version_lcmu = "0.13.0"
if version.parse(importlib.metadata.version("peft")) >= version.parse(min_version_lcmu): if version.parse(importlib.metadata.version("peft")) >= version.parse(min_version_lcmu):
@ -233,6 +235,13 @@ class PeftAdapterMixin:
new_key = key[len(prefix) :] new_key = key[len(prefix) :]
else: else:
new_key = key new_key = key
if key_mapping:
for pattern, replacement in key_mapping.items():
new_key, n_replace = re.subn(pattern, replacement, new_key)
# Early exit of the loop
if n_replace > 0:
break
processed_adapter_state_dict[new_key] = value processed_adapter_state_dict[new_key] = value
# Load state dict # Load state dict

View File

@ -4778,6 +4778,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
model.hf_quantizer = hf_quantizer model.hf_quantizer = hf_quantizer
if _adapter_model_path is not None: if _adapter_model_path is not None:
adapter_kwargs["key_mapping"] = key_mapping
model.load_adapter( model.load_adapter(
_adapter_model_path, _adapter_model_path,
adapter_name=adapter_name, adapter_name=adapter_name,