mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 05:10:06 +06:00
add default mapping to peft integration
This commit is contained in:
parent
ce6ac53ac1
commit
608884960e
@ -28,6 +28,7 @@ from ..utils import (
|
|||||||
is_torch_available,
|
is_torch_available,
|
||||||
logging,
|
logging,
|
||||||
)
|
)
|
||||||
|
from ..modeling_utils import VLMS
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@ -151,6 +152,8 @@ class PeftAdapterMixin:
|
|||||||
# peft only supports low_cpu_mem_usage starting from v0.13.0
|
# peft only supports low_cpu_mem_usage starting from v0.13.0
|
||||||
peft_load_kwargs = {}
|
peft_load_kwargs = {}
|
||||||
key_mapping = adapter_kwargs.pop("key_mapping", None) if adapter_kwargs is not None else None
|
key_mapping = adapter_kwargs.pop("key_mapping", None) if adapter_kwargs is not None else None
|
||||||
|
if key_mapping is None and any(allowed_name in self.__class__.__name__.lower() for allowed_name in VLMS):
|
||||||
|
key_mapping = self._checkpoint_conversion_mapping
|
||||||
if low_cpu_mem_usage:
|
if low_cpu_mem_usage:
|
||||||
min_version_lcmu = "0.13.0"
|
min_version_lcmu = "0.13.0"
|
||||||
if version.parse(importlib.metadata.version("peft")) >= version.parse(min_version_lcmu):
|
if version.parse(importlib.metadata.version("peft")) >= version.parse(min_version_lcmu):
|
||||||
|
@ -4251,11 +4251,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
|||||||
device_mesh = kwargs.pop("device_mesh", None)
|
device_mesh = kwargs.pop("device_mesh", None)
|
||||||
trust_remote_code = kwargs.pop("trust_remote_code", None)
|
trust_remote_code = kwargs.pop("trust_remote_code", None)
|
||||||
|
|
||||||
# Load models with hardcoded key mapping on class for VLMs only, to keep BC and standardize model
|
|
||||||
if any(allowed_name in cls.__name__.lower() for allowed_name in VLMS):
|
|
||||||
key_mapping = kwargs.pop("key_mapping", cls._checkpoint_conversion_mapping)
|
|
||||||
else:
|
|
||||||
key_mapping = kwargs.pop("key_mapping", None)
|
key_mapping = kwargs.pop("key_mapping", None)
|
||||||
|
# Load models with hardcoded key mapping on class for VLMs only, to keep BC and standardize model
|
||||||
|
if key_mapping is None and any(allowed_name in cls.__name__.lower() for allowed_name in VLMS):
|
||||||
|
key_mapping = cls._checkpoint_conversion_mapping
|
||||||
|
|
||||||
# Not used anymore -- remove them from the kwargs
|
# Not used anymore -- remove them from the kwargs
|
||||||
_ = kwargs.pop("resume_download", None)
|
_ = kwargs.pop("resume_download", None)
|
||||||
|
Loading…
Reference in New Issue
Block a user