mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix key mapping for VLMs (#39029)
* fix key mapping for VLMs * use __mro__ instead * update key mapping in save_pretrained
This commit is contained in:
parent
3457e8e73e
commit
d53518c5f2
@ -3746,7 +3746,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
module_map[name + f".{key}"] = module
|
||||
state_dict = model_to_save.state_dict()
|
||||
|
||||
if any(allowed_name in self.__class__.__name__.lower() for allowed_name in VLMS):
|
||||
if any(
|
||||
allowed_name in class_name.__name__.lower()
|
||||
for class_name in self.__class__.__mro__[:-1]
|
||||
for allowed_name in VLMS
|
||||
):
|
||||
reverse_key_mapping = {v: k for k, v in self._checkpoint_conversion_mapping.items()}
|
||||
|
||||
original_state_dict = {}
|
||||
@ -4402,7 +4406,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
|
||||
key_mapping = kwargs.pop("key_mapping", None)
|
||||
# Load models with hardcoded key mapping on class for VLMs only, to keep BC and standardize model
|
||||
if key_mapping is None and any(allowed_name in cls.__name__.lower() for allowed_name in VLMS):
|
||||
if key_mapping is None and any(
|
||||
allowed_name in class_name.__name__.lower() for class_name in cls.__mro__[:-1] for allowed_name in VLMS
|
||||
):
|
||||
key_mapping = cls._checkpoint_conversion_mapping
|
||||
|
||||
# Not used anymore -- remove them from the kwargs
|
||||
|
Loading…
Reference in New Issue
Block a user