Avoid overwrite existing local implementation when loading remote custom model (#38474)

* avoid overwrite existing local implementation when loading custom remote model

Signed-off-by: Isotr0py <2037008807@qq.com>

* update comments

Signed-off-by: Isotr0py <2037008807@qq.com>

---------

Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
Isotr0py 2025-06-05 20:54:40 +08:00 committed by GitHub
parent 8f630651b0
commit 0f833528c9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -442,8 +442,12 @@ class _BaseAutoModelClass:
else: else:
repo_id = config.name_or_path repo_id = config.name_or_path
model_class = get_class_from_dynamic_module(class_ref, repo_id, **kwargs) model_class = get_class_from_dynamic_module(class_ref, repo_id, **kwargs)
model_class.register_for_auto_class(auto_class=cls) # This block handles the case where the user is loading a model with `trust_remote_code=True`
# but a library model exists with the same name. We don't want to override the autoclass
# mappings in this case, or all future loads of that model will be the remote code model.
if not has_local_code:
cls.register(config.__class__, model_class, exist_ok=True) cls.register(config.__class__, model_class, exist_ok=True)
model_class.register_for_auto_class(auto_class=cls)
_ = kwargs.pop("code_revision", None) _ = kwargs.pop("code_revision", None)
model_class = add_generation_mixin_to_remote_model(model_class) model_class = add_generation_mixin_to_remote_model(model_class)
return model_class._from_config(config, **kwargs) return model_class._from_config(config, **kwargs)
@ -579,6 +583,10 @@ class _BaseAutoModelClass:
class_ref, pretrained_model_name_or_path, code_revision=code_revision, **hub_kwargs, **kwargs class_ref, pretrained_model_name_or_path, code_revision=code_revision, **hub_kwargs, **kwargs
) )
_ = hub_kwargs.pop("code_revision", None) _ = hub_kwargs.pop("code_revision", None)
# This block handles the case where the user is loading a model with `trust_remote_code=True`
# but a library model exists with the same name. We don't want to override the autoclass
# mappings in this case, or all future loads of that model will be the remote code model.
if not has_local_code:
cls.register(config.__class__, model_class, exist_ok=True) cls.register(config.__class__, model_class, exist_ok=True)
model_class.register_for_auto_class(auto_class=cls) model_class.register_for_auto_class(auto_class=cls)
model_class = add_generation_mixin_to_remote_model(model_class) model_class = add_generation_mixin_to_remote_model(model_class)