diff --git a/src/transformers/models/auto/auto_factory.py b/src/transformers/models/auto/auto_factory.py index 2dc3d0b0c0c..084f1472830 100644 --- a/src/transformers/models/auto/auto_factory.py +++ b/src/transformers/models/auto/auto_factory.py @@ -442,8 +442,12 @@ class _BaseAutoModelClass: else: repo_id = config.name_or_path model_class = get_class_from_dynamic_module(class_ref, repo_id, **kwargs) - model_class.register_for_auto_class(auto_class=cls) - cls.register(config.__class__, model_class, exist_ok=True) + # This block handles the case where the user is loading a model with `trust_remote_code=True` + # but a library model exists with the same name. We don't want to override the autoclass + # mappings in this case, or all future loads of that model will be the remote code model. + if not has_local_code: + cls.register(config.__class__, model_class, exist_ok=True) + model_class.register_for_auto_class(auto_class=cls) _ = kwargs.pop("code_revision", None) model_class = add_generation_mixin_to_remote_model(model_class) return model_class._from_config(config, **kwargs) @@ -579,8 +583,12 @@ class _BaseAutoModelClass: class_ref, pretrained_model_name_or_path, code_revision=code_revision, **hub_kwargs, **kwargs ) _ = hub_kwargs.pop("code_revision", None) - cls.register(config.__class__, model_class, exist_ok=True) - model_class.register_for_auto_class(auto_class=cls) + # This block handles the case where the user is loading a model with `trust_remote_code=True` + # but a library model exists with the same name. We don't want to override the autoclass + # mappings in this case, or all future loads of that model will be the remote code model. + if not has_local_code: + cls.register(config.__class__, model_class, exist_ok=True) + model_class.register_for_auto_class(auto_class=cls) model_class = add_generation_mixin_to_remote_model(model_class) return model_class.from_pretrained( pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs