Fix key mapping for VLMs (#39029)

* fix key mapping for VLMs

* use __mro__ instead

* update key mapping in save_pretrained
Author: BUI Van Tuan (committed by GitHub)
Date: 2025-07-01 09:47:53 +02:00
Parent: 3457e8e73e
Commit: d53518c5f2
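
Why the MRO walk: a user subclass of a VLM does not carry the VLM identifier in its own class name, so the old check on self.__class__.__name__ silently skipped the checkpoint key mapping for subclasses. A minimal sketch of the before/after behavior, using hypothetical class names and a simplified VLMS constant:

    VLMS = ["llava"]  # simplified stand-in for the real constant

    class LlavaForConditionalGeneration:  # stands in for the transformers class
        pass

    class MyFineTunedLlava(LlavaForConditionalGeneration):  # hypothetical user subclass
        pass

    cls = MyFineTunedLlava

    # Old check: inspects only the concrete class name -> False for the subclass
    old = any(allowed_name in cls.__name__.lower() for allowed_name in VLMS)

    # New check: walks the MRO, dropping `object` via [:-1] -> True,
    # because LlavaForConditionalGeneration appears among the bases
    new = any(
        allowed_name in class_name.__name__.lower()
        for class_name in cls.__mro__[:-1]
        for allowed_name in VLMS
    )

    print(old, new)  # False True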

@@ -3746,7 +3746,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
                     module_map[name + f".{key}"] = module
         state_dict = model_to_save.state_dict()
 
-        if any(allowed_name in self.__class__.__name__.lower() for allowed_name in VLMS):
+        if any(
+            allowed_name in class_name.__name__.lower()
+            for class_name in self.__class__.__mro__[:-1]
+            for allowed_name in VLMS
+        ):
             reverse_key_mapping = {v: k for k, v in self._checkpoint_conversion_mapping.items()}
 
             original_state_dict = {}
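
Context for the reverse_key_mapping line: _checkpoint_conversion_mapping renames legacy checkpoint keys to the new layout at load time, so save_pretrained has to apply the inverse to write checkpoints back out in the original format. A rough sketch of that inversion under a hypothetical one-entry mapping (the real mappings are regex patterns defined per model class):

    import re

    # Hypothetical load-time mapping: keys matching the left side are renamed to the right
    _checkpoint_conversion_mapping = {r"^language_model.model": "model.language_model"}

    # Save-time: invert it so new-layout keys revert to the legacy names
    reverse_key_mapping = {v: k for k, v in _checkpoint_conversion_mapping.items()}

    state_dict = {"model.language_model.embed_tokens.weight": 0}
    original_state_dict = {}
    for key, value in state_dict.items():
        for pattern, replacement in reverse_key_mapping.items():
            replacement = replacement.lstrip("^")  # the old anchor is not a literal in the output
            key, n_replaced = re.subn(pattern, replacement, key)
            if n_replaced:
                break
        original_state_dict[key] = value

    print(original_state_dict)  # {'language_model.model.embed_tokens.weight': 0}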
@@ -4402,7 +4406,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
 
         key_mapping = kwargs.pop("key_mapping", None)
         # Load models with hardcoded key mapping on class for VLMs only, to keep BC and standardize model
-        if key_mapping is None and any(allowed_name in cls.__name__.lower() for allowed_name in VLMS):
+        if key_mapping is None and any(
+            allowed_name in class_name.__name__.lower() for class_name in cls.__mro__[:-1] for allowed_name in VLMS
+        ):
             key_mapping = cls._checkpoint_conversion_mapping
 
         # Not used anymore -- remove them from the kwargs
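
The practical effect of the second hunk, sketched with a hypothetical subclass: from_pretrained now resolves the hardcoded key mapping through the MRO, so fine-tuned subclasses of a VLM load legacy checkpoints the same way the base class does (the checkpoint id below is only an example):

    from transformers import LlavaForConditionalGeneration

    class MyLlavaVariant(LlavaForConditionalGeneration):  # hypothetical user subclass
        pass

    # Before this change: "llava" is not a substring of "myllavavariant", so
    # key_mapping stayed None and legacy keys were not renamed on load.
    # After: the MRO walk finds LlavaForConditionalGeneration, so
    # key_mapping = cls._checkpoint_conversion_mapping applies as for the base class.
    model = MyLlavaVariant.from_pretrained("llava-hf/llava-1.5-7b-hf")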