diff --git a/src/transformers/integrations/peft.py b/src/transformers/integrations/peft.py
index 7148da3dd12..7c9e37c2786 100644
--- a/src/transformers/integrations/peft.py
+++ b/src/transformers/integrations/peft.py
@@ -28,6 +28,7 @@ from ..utils import (
     is_torch_available,
     logging,
 )
+from ..modeling_utils import VLMS


 if is_torch_available():
@@ -151,6 +152,8 @@ class PeftAdapterMixin:
         # peft only supports low_cpu_mem_usage starting from v0.13.0
         peft_load_kwargs = {}
         key_mapping = adapter_kwargs.pop("key_mapping", None) if adapter_kwargs is not None else None
+        if key_mapping is None and any(allowed_name in self.__class__.__name__.lower() for allowed_name in VLMS):
+            key_mapping = self._checkpoint_conversion_mapping
         if low_cpu_mem_usage:
             min_version_lcmu = "0.13.0"
             if version.parse(importlib.metadata.version("peft")) >= version.parse(min_version_lcmu):
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 9e30fbb43a1..a0f5180ac8b 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -4251,11 +4251,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMixin):
         device_mesh = kwargs.pop("device_mesh", None)
         trust_remote_code = kwargs.pop("trust_remote_code", None)

-        # Load models with hardcoded key mapping on class for VLMs only, to keep BC and standardize model
-        if any(allowed_name in cls.__name__.lower() for allowed_name in VLMS):
-            key_mapping = kwargs.pop("key_mapping", cls._checkpoint_conversion_mapping)
-        else:
-            key_mapping = kwargs.pop("key_mapping", None)
+        key_mapping = kwargs.pop("key_mapping", None)
+        # Load models with hardcoded key mapping on class for VLMs only, to keep BC and standardize model
+        if key_mapping is None and any(allowed_name in cls.__name__.lower() for allowed_name in VLMS):
+            key_mapping = cls._checkpoint_conversion_mapping

         # Not used anymore -- remove them from the kwargs
         _ = kwargs.pop("resume_download", None)
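
With this change, `from_pretrained` and `load_adapter` resolve `key_mapping` the same way: an explicitly passed mapping always wins, and only when none is given do VLM classes fall back to their class-level `_checkpoint_conversion_mapping`. Below is a minimal sketch of that shared resolution order; `resolve_key_mapping` is a hypothetical helper written only for illustration (it is not part of the patch), while `VLMS` and `_checkpoint_conversion_mapping` are the names used in the diff above.

```python
from transformers.modeling_utils import VLMS  # same constant the patch imports into peft.py


def resolve_key_mapping(model_cls, explicit_key_mapping=None):
    """Illustrative sketch of the key_mapping fallback introduced by this patch."""
    # 1. An explicitly passed key_mapping always takes precedence.
    if explicit_key_mapping is not None:
        return explicit_key_mapping
    # 2. VLM classes fall back to the hardcoded class-level conversion mapping.
    if any(allowed_name in model_cls.__name__.lower() for allowed_name in VLMS):
        return model_cls._checkpoint_conversion_mapping
    # 3. All other models keep the previous behavior: no remapping.
    return None
```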