mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 21:00:08 +06:00
Avoid overwrite existing local implementation when loading remote custom model (#38474)
* avoid overwrite existing local implementation when loading custom remote model Signed-off-by: Isotr0py <2037008807@qq.com> * update comments Signed-off-by: Isotr0py <2037008807@qq.com> --------- Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
parent
8f630651b0
commit
0f833528c9
@ -442,8 +442,12 @@ class _BaseAutoModelClass:
|
|||||||
else:
|
else:
|
||||||
repo_id = config.name_or_path
|
repo_id = config.name_or_path
|
||||||
model_class = get_class_from_dynamic_module(class_ref, repo_id, **kwargs)
|
model_class = get_class_from_dynamic_module(class_ref, repo_id, **kwargs)
|
||||||
model_class.register_for_auto_class(auto_class=cls)
|
# This block handles the case where the user is loading a model with `trust_remote_code=True`
|
||||||
|
# but a library model exists with the same name. We don't want to override the autoclass
|
||||||
|
# mappings in this case, or all future loads of that model will be the remote code model.
|
||||||
|
if not has_local_code:
|
||||||
cls.register(config.__class__, model_class, exist_ok=True)
|
cls.register(config.__class__, model_class, exist_ok=True)
|
||||||
|
model_class.register_for_auto_class(auto_class=cls)
|
||||||
_ = kwargs.pop("code_revision", None)
|
_ = kwargs.pop("code_revision", None)
|
||||||
model_class = add_generation_mixin_to_remote_model(model_class)
|
model_class = add_generation_mixin_to_remote_model(model_class)
|
||||||
return model_class._from_config(config, **kwargs)
|
return model_class._from_config(config, **kwargs)
|
||||||
@ -579,6 +583,10 @@ class _BaseAutoModelClass:
|
|||||||
class_ref, pretrained_model_name_or_path, code_revision=code_revision, **hub_kwargs, **kwargs
|
class_ref, pretrained_model_name_or_path, code_revision=code_revision, **hub_kwargs, **kwargs
|
||||||
)
|
)
|
||||||
_ = hub_kwargs.pop("code_revision", None)
|
_ = hub_kwargs.pop("code_revision", None)
|
||||||
|
# This block handles the case where the user is loading a model with `trust_remote_code=True`
|
||||||
|
# but a library model exists with the same name. We don't want to override the autoclass
|
||||||
|
# mappings in this case, or all future loads of that model will be the remote code model.
|
||||||
|
if not has_local_code:
|
||||||
cls.register(config.__class__, model_class, exist_ok=True)
|
cls.register(config.__class__, model_class, exist_ok=True)
|
||||||
model_class.register_for_auto_class(auto_class=cls)
|
model_class.register_for_auto_class(auto_class=cls)
|
||||||
model_class = add_generation_mixin_to_remote_model(model_class)
|
model_class = add_generation_mixin_to_remote_model(model_class)
|
||||||
|
Loading…
Reference in New Issue
Block a user