🚨 🚨 Fix custom code saving (#37716)

* Firstly: Better detection of when we're a custom class

* Trigger tests

* Let's break everything

* make fixup

* fix mistaken line doubling

* Let's try to get rid of it from config classes at least

* Fixup image processor

* no more circular import

* Let's go back to setting `_auto_class` again

* stash commit

* Revert the irrelevant changes until we figure out AutoConfig

* Change tests since we're breaking expectations

* make fixup

* do the same for all custom classes

* Cleanup for feature extractor tests

* Cleanup tokenization tests too

* typo

* Fix tokenizer tests

* make fixup

* fix image processor test

* make fixup

* Remove warning from register_for_auto_class

* Stop adding model info to auto map entirely

* Remove todo

* Remove the other todo

* Let's start slapping _auto_class on models why not

* Make sure the tests know what's up

* Completely remove add_model_info_to_*

* Start adding _auto_class to models

* Add a flaky decorator

* Add a flaky decorator and import

* stash commit

* More message cleanup

* make fixup

* fix indent

* Fix trust_remote_code prompts

* make fixup

* correct indentation

* Reincorporate changes into dynamic_module_utils

* Update call to trust_remote_code

* make fixup

* Fix video processors too

* Remove is_flaky additions

* make fixup
Commit by Matt, 2025-05-26 17:37:30 +01:00, committed by GitHub
parent 701caef704
commit ba6d72226d
25 changed files with 120 additions and 231 deletions

View File

@@ -28,8 +28,6 @@ from .modeling_gguf_pytorch_utils import load_gguf_checkpoint
from .utils import (
CONFIG_NAME,
PushToHubMixin,
add_model_info_to_auto_map,
add_model_info_to_custom_pipelines,
cached_file,
copy_func,
download_url,
@@ -713,15 +711,6 @@ class PretrainedConfig(PushToHubMixin):
else:
logger.info(f"loading configuration file {configuration_file} from cache at {resolved_config_file}")
if "auto_map" in config_dict and not is_local:
config_dict["auto_map"] = add_model_info_to_auto_map(
config_dict["auto_map"], pretrained_model_name_or_path
)
if "custom_pipelines" in config_dict and not is_local:
config_dict["custom_pipelines"] = add_model_info_to_custom_pipelines(
config_dict["custom_pipelines"], pretrained_model_name_or_path
)
# timm models are not saved with the model_type in the config file
if "model_type" not in config_dict and is_timm_config_dict(config_dict):
config_dict["model_type"] = "timm_wrapper"
@@ -1044,11 +1033,7 @@ class PretrainedConfig(PushToHubMixin):
Register this class with a given auto class. This should only be used for custom configurations as the ones in
the library are already mapped with `AutoConfig`.
<Tip warning={true}>
This API is experimental and may have some slight breaking changes in the next releases.
</Tip>
Args:
auto_class (`str` or `type`, *optional*, defaults to `"AutoConfig"`):
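For context: values in `auto_map` come in two shapes. A bare `module.Class` reference means the code lives in the repo being loaded, while a `repo--module.Class` reference points at code hosted in a different upstream repo. The deleted block stamped the loading repo's id onto every bare reference at download time, so a config re-saved locally still pointed back at the hub repo instead of the code file saved beside it. A minimal sketch of the two formats (all names hypothetical):

```python
# Two shapes of auto_map references: bare (code in this repo) and
# "repo--ref" (code hosted in an upstream repo).
auto_map = {
    "AutoConfig": "configuration.NewModelConfig",          # local to the repo
    "AutoModel": "some-org/upstream--modeling.NewModel",   # hosted upstream
}

for auto_class, ref in auto_map.items():
    repo, _, class_ref = ref.rpartition("--")
    print(auto_class, "->", class_ref, "from", repo or "this repo")
```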

View File

@@ -667,7 +667,9 @@ def _raise_timeout_error(signum, frame):
TIME_OUT_REMOTE_CODE = 15
def resolve_trust_remote_code(trust_remote_code, model_name, has_local_code, has_remote_code, error_message=None):
def resolve_trust_remote_code(
trust_remote_code, model_name, has_local_code, has_remote_code, error_message=None, upstream_repo=None
):
"""
Resolves the `trust_remote_code` argument. If there is remote code to be loaded, the user must opt-in to loading
it.
@@ -688,11 +690,25 @@ def resolve_trust_remote_code(trust_remote_code, model_name, has_local_code, has
Returns:
The resolved `trust_remote_code` value.
"""
# Originally, `trust_remote_code` was used to load models with custom code.
error_message = (
error_message
or f"The repository `{model_name}` contains custom code which must be executed to correctly load the model."
)
if error_message is None:
if upstream_repo is not None:
error_message = (
f"The repository {model_name} references custom code contained in {upstream_repo} which "
f"must be executed to correctly load the model. You can inspect the repository "
f"content at https://hf.co/{upstream_repo} .\n"
)
elif os.path.isdir(model_name):
error_message = (
f"The repository {model_name} contains custom code which must be executed "
f"to correctly load the model. You can inspect the repository "
f"content at {os.path.abspath(model_name)} .\n"
)
else:
error_message = (
f"The repository {model_name} contains custom code which must be executed "
f"to correctly load the model. You can inspect the repository "
f"content at https://hf.co/{model_name} .\n"
)
if trust_remote_code is None:
if has_local_code:
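The new `upstream_repo` parameter lets the opt-in prompt name the repo that actually hosts the code when the reference is `--`-qualified. A sketch of the calling convention, assuming transformers is installed (repo names are hypothetical; with `trust_remote_code=None` and no local code, the real function prompts the user interactively, backed by the `TIME_OUT_REMOTE_CODE` timeout above):

```python
from transformers.dynamic_module_utils import resolve_trust_remote_code

class_ref = "some-org/upstream--modeling.NewModel"
upstream_repo = class_ref.split("--")[0] if "--" in class_ref else None

trust_remote_code = resolve_trust_remote_code(
    None,                         # user passed no explicit trust_remote_code
    "some-user/derived-model",    # repo being loaded
    has_local_code=False,
    has_remote_code=True,
    upstream_repo=upstream_repo,  # prompt points at the code's real location
)
```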

View File

@@ -29,8 +29,6 @@ from .utils import (
FEATURE_EXTRACTOR_NAME,
PushToHubMixin,
TensorType,
add_model_info_to_auto_map,
add_model_info_to_custom_pipelines,
cached_file,
copy_func,
download_url,
@@ -551,16 +549,6 @@ class FeatureExtractionMixin(PushToHubMixin):
f"loading configuration file {feature_extractor_file} from cache at {resolved_feature_extractor_file}"
)
if not is_local:
if "auto_map" in feature_extractor_dict:
feature_extractor_dict["auto_map"] = add_model_info_to_auto_map(
feature_extractor_dict["auto_map"], pretrained_model_name_or_path
)
if "custom_pipelines" in feature_extractor_dict:
feature_extractor_dict["custom_pipelines"] = add_model_info_to_custom_pipelines(
feature_extractor_dict["custom_pipelines"], pretrained_model_name_or_path
)
return feature_extractor_dict, kwargs
@classmethod
@@ -673,11 +661,7 @@ class FeatureExtractionMixin(PushToHubMixin):
Register this class with a given auto class. This should only be used for custom feature extractors as the ones
in the library are already mapped with `AutoFeatureExtractor`.
<Tip warning={true}>
This API is experimental and may have some slight breaking changes in the next releases.
</Tip>
Args:
auto_class (`str` or `type`, *optional*, defaults to `"AutoFeatureExtractor"`):

View File

@@ -28,8 +28,6 @@ from .feature_extraction_utils import BatchFeature as BaseBatchFeature
from .utils import (
IMAGE_PROCESSOR_NAME,
PushToHubMixin,
add_model_info_to_auto_map,
add_model_info_to_custom_pipelines,
cached_file,
copy_func,
download_url,
@@ -380,14 +378,6 @@ class ImageProcessingMixin(PushToHubMixin):
logger.info(
f"loading configuration file {image_processor_file} from cache at {resolved_image_processor_file}"
)
if "auto_map" in image_processor_dict:
image_processor_dict["auto_map"] = add_model_info_to_auto_map(
image_processor_dict["auto_map"], pretrained_model_name_or_path
)
if "custom_pipelines" in image_processor_dict:
image_processor_dict["custom_pipelines"] = add_model_info_to_custom_pipelines(
image_processor_dict["custom_pipelines"], pretrained_model_name_or_path
)
return image_processor_dict, kwargs
@@ -508,11 +498,7 @@ class ImageProcessingMixin(PushToHubMixin):
Register this class with a given auto class. This should only be used for custom image processors as the ones
in the library are already mapped with `AutoImageProcessor`.
<Tip warning={true}>
This API is experimental and may have some slight breaking changes in the next releases.
</Tip>
Args:
auto_class (`str` or `type`, *optional*, defaults to `"AutoImageProcessor"`):

View File

@@ -1218,11 +1218,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
Register this class with a given auto class. This should only be used for custom models as the ones in the
library are already mapped with an auto class.
<Tip warning={true}>
This API is experimental and may have some slight breaking changes in the next releases.
</Tip>
Args:
auto_class (`str` or `type`, *optional*, defaults to `"FlaxAutoModel"`):

View File

@@ -3229,11 +3229,7 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT
Register this class with a given auto class. This should only be used for custom models as the ones in the
library are already mapped with an auto class.
<Tip warning={true}>
This API is experimental and may have some slight breaking changes in the next releases.
</Tip>
Args:
auto_class (`str` or `type`, *optional*, defaults to `"TFAutoModel"`):

View File

@@ -5321,11 +5321,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
Register this class with a given auto class. This should only be used for custom models as the ones in the
library are already mapped with an auto class.
<Tip warning={true}>
This API is experimental and may have some slight breaking changes in the next releases.
</Tip>
Args:
auto_class (`str` or `type`, *optional*, defaults to `"AutoModel"`):
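The `<Tip>` removal across the Flax, TF, and PyTorch base models reflects that `register_for_auto_class` is no longer treated as experimental: the Auto loaders in this commit now call it themselves on every remote class they resolve. Setting `_auto_class` is what makes `save_pretrained()` copy the class's defining module and record an `auto_map` entry. A minimal sketch, assuming transformers is installed (`MyConfig` is hypothetical and must live in an importable module file for the code copy to work):

```python
from transformers import PretrainedConfig

class MyConfig(PretrainedConfig):
    model_type = "my-model"

# Sets MyConfig._auto_class = "AutoConfig"; save_pretrained() will then save
# the defining .py file next to config.json and record an auto_map entry.
MyConfig.register_for_auto_class("AutoConfig")
assert MyConfig._auto_class == "AutoConfig"
```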

View File

@@ -420,17 +420,23 @@ class _BaseAutoModelClass:
trust_remote_code = kwargs.pop("trust_remote_code", None)
has_remote_code = hasattr(config, "auto_map") and cls.__name__ in config.auto_map
has_local_code = type(config) in cls._model_mapping.keys()
trust_remote_code = resolve_trust_remote_code(
trust_remote_code, config._name_or_path, has_local_code, has_remote_code
)
if has_remote_code:
class_ref = config.auto_map[cls.__name__]
if "--" in class_ref:
upstream_repo = class_ref.split("--")[0]
else:
upstream_repo = None
trust_remote_code = resolve_trust_remote_code(
trust_remote_code, config._name_or_path, has_local_code, has_remote_code, upstream_repo=upstream_repo
)
if has_remote_code and trust_remote_code:
class_ref = config.auto_map[cls.__name__]
if "--" in class_ref:
repo_id, class_ref = class_ref.split("--")
else:
repo_id = config.name_or_path
model_class = get_class_from_dynamic_module(class_ref, repo_id, **kwargs)
model_class.register_for_auto_class(auto_class=cls)
cls.register(config.__class__, model_class, exist_ok=True)
_ = kwargs.pop("code_revision", None)
model_class = add_generation_mixin_to_remote_model(model_class)
@@ -545,8 +551,17 @@ class _BaseAutoModelClass:
has_remote_code = hasattr(config, "auto_map") and cls.__name__ in config.auto_map
has_local_code = type(config) in cls._model_mapping.keys()
upstream_repo = None
if has_remote_code:
class_ref = config.auto_map[cls.__name__]
if "--" in class_ref:
upstream_repo = class_ref.split("--")[0]
trust_remote_code = resolve_trust_remote_code(
trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code
trust_remote_code,
pretrained_model_name_or_path,
has_local_code,
has_remote_code,
upstream_repo=upstream_repo,
)
kwargs["trust_remote_code"] = trust_remote_code
@@ -554,12 +569,12 @@ class _BaseAutoModelClass:
kwargs["adapter_kwargs"] = adapter_kwargs
if has_remote_code and trust_remote_code:
class_ref = config.auto_map[cls.__name__]
model_class = get_class_from_dynamic_module(
class_ref, pretrained_model_name_or_path, code_revision=code_revision, **hub_kwargs, **kwargs
)
_ = hub_kwargs.pop("code_revision", None)
cls.register(config.__class__, model_class, exist_ok=True)
model_class.register_for_auto_class(auto_class=cls)
model_class = add_generation_mixin_to_remote_model(model_class)
return model_class.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
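Two details of this restructure are worth spelling out: the `--` split recovers the repo that hosts the code, and the new `model_class.register_for_auto_class(auto_class=cls)` call stamps `_auto_class` on the resolved class so a later `save_pretrained()` saves its code. A sketch of the split, mirroring the logic in `from_config` above (values hypothetical):

```python
def split_class_ref(class_ref: str, default_repo: str) -> tuple[str, str]:
    # "repo--module.Class" -> code hosted upstream; bare ref -> this repo.
    if "--" in class_ref:
        repo_id, class_ref = class_ref.split("--")
        return repo_id, class_ref
    return default_repo, class_ref

print(split_class_ref("modeling.NewModel", "some-user/my-model"))
# ('some-user/my-model', 'modeling.NewModel')
print(split_class_ref("some-org/upstream--modeling.NewModel", "some-user/my-model"))
# ('some-org/upstream', 'modeling.NewModel')
```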

View File

@@ -15,7 +15,6 @@
"""Auto Config class."""
import importlib
import os
import re
import warnings
from collections import OrderedDict
@@ -1155,17 +1154,21 @@ class AutoConfig:
config_dict, unused_kwargs = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs)
has_remote_code = "auto_map" in config_dict and "AutoConfig" in config_dict["auto_map"]
has_local_code = "model_type" in config_dict and config_dict["model_type"] in CONFIG_MAPPING
trust_remote_code = resolve_trust_remote_code(
trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code
)
if has_remote_code:
class_ref = config_dict["auto_map"]["AutoConfig"]
if "--" in class_ref:
upstream_repo = class_ref.split("--")[0]
else:
upstream_repo = None
trust_remote_code = resolve_trust_remote_code(
trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code, upstream_repo
)
if has_remote_code and trust_remote_code:
class_ref = config_dict["auto_map"]["AutoConfig"]
config_class = get_class_from_dynamic_module(
class_ref, pretrained_model_name_or_path, code_revision=code_revision, **kwargs
)
if os.path.isdir(pretrained_model_name_or_path):
config_class.register_for_auto_class()
config_class.register_for_auto_class()
return config_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
elif "model_type" in config_dict:
try:
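Note the behavioral change in the last hunk: `register_for_auto_class()` used to run only when loading from a local directory, and now runs for hub repos as well, which is the core of this fix. A sketch of the repaired flow, using the test repo referenced in the tests below (requires network access):

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained(
    "hf-internal-testing/test_dynamic_model", trust_remote_code=True
)
# The dynamically loaded class now carries _auto_class, so saving it will
# also save its defining module (see the updated tests below).
print(type(config).__name__, getattr(type(config), "_auto_class", None))
```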

View File

@@ -371,17 +371,21 @@ class AutoFeatureExtractor:
has_remote_code = feature_extractor_auto_map is not None
has_local_code = feature_extractor_class is not None or type(config) in FEATURE_EXTRACTOR_MAPPING
trust_remote_code = resolve_trust_remote_code(
trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code
)
if has_remote_code:
if "--" in feature_extractor_auto_map:
upstream_repo = feature_extractor_auto_map.split("--")[0]
else:
upstream_repo = None
trust_remote_code = resolve_trust_remote_code(
trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code, upstream_repo
)
if has_remote_code and trust_remote_code:
feature_extractor_class = get_class_from_dynamic_module(
feature_extractor_auto_map, pretrained_model_name_or_path, **kwargs
)
_ = kwargs.pop("code_revision", None)
if os.path.isdir(pretrained_model_name_or_path):
feature_extractor_class.register_for_auto_class()
feature_extractor_class.register_for_auto_class()
return feature_extractor_class.from_dict(config_dict, **kwargs)
elif feature_extractor_class is not None:
return feature_extractor_class.from_dict(config_dict, **kwargs)

View File

@@ -539,26 +539,29 @@ class AutoImageProcessor:
has_remote_code = image_processor_auto_map is not None
has_local_code = image_processor_class is not None or type(config) in IMAGE_PROCESSOR_MAPPING
trust_remote_code = resolve_trust_remote_code(
trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code
)
if image_processor_auto_map is not None and not isinstance(image_processor_auto_map, tuple):
# In some configs, only the slow image processor class is stored
image_processor_auto_map = (image_processor_auto_map, None)
if has_remote_code:
if image_processor_auto_map is not None and not isinstance(image_processor_auto_map, tuple):
# In some configs, only the slow image processor class is stored
image_processor_auto_map = (image_processor_auto_map, None)
if use_fast and image_processor_auto_map[1] is not None:
class_ref = image_processor_auto_map[1]
else:
class_ref = image_processor_auto_map[0]
if "--" in class_ref:
upstream_repo = class_ref.split("--")[0]
else:
upstream_repo = None
trust_remote_code = resolve_trust_remote_code(
trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code, upstream_repo
)
if has_remote_code and trust_remote_code:
if not use_fast and image_processor_auto_map[1] is not None:
_warning_fast_image_processor_available(image_processor_auto_map[1])
if use_fast and image_processor_auto_map[1] is not None:
class_ref = image_processor_auto_map[1]
else:
class_ref = image_processor_auto_map[0]
image_processor_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs)
_ = kwargs.pop("code_revision", None)
if os.path.isdir(pretrained_model_name_or_path):
image_processor_class.register_for_auto_class()
image_processor_class.register_for_auto_class()
return image_processor_class.from_dict(config_dict, **kwargs)
elif image_processor_class is not None:
return image_processor_class.from_dict(config_dict, **kwargs)
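Image processors store their `auto_map` entry as a `(slow, fast)` pair in which the fast slot may be `None`; the hunk moves both the tuple normalization and the `use_fast` selection under `has_remote_code` so `upstream_repo` can be derived from the reference actually being loaded. A sketch of the selection (values hypothetical):

```python
# (slow class reference, fast class reference or None)
image_processor_auto_map = ("image_processor.NewImageProcessor", None)

use_fast = True
if use_fast and image_processor_auto_map[1] is not None:
    class_ref = image_processor_auto_map[1]   # prefer the fast variant
else:
    class_ref = image_processor_auto_map[0]   # fall back to the slow one
print(class_ref)  # image_processor.NewImageProcessor
```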

View File

@@ -17,7 +17,6 @@
import importlib
import inspect
import json
import os
import warnings
from collections import OrderedDict
@@ -358,17 +357,21 @@ class AutoProcessor:
has_remote_code = processor_auto_map is not None
has_local_code = processor_class is not None or type(config) in PROCESSOR_MAPPING
trust_remote_code = resolve_trust_remote_code(
trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code
)
if has_remote_code:
if "--" in processor_auto_map:
upstream_repo = processor_auto_map.split("--")[0]
else:
upstream_repo = None
trust_remote_code = resolve_trust_remote_code(
trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code, upstream_repo
)
if has_remote_code and trust_remote_code:
processor_class = get_class_from_dynamic_module(
processor_auto_map, pretrained_model_name_or_path, **kwargs
)
_ = kwargs.pop("code_revision", None)
if os.path.isdir(pretrained_model_name_or_path):
processor_class.register_for_auto_class()
processor_class.register_for_auto_class()
return processor_class.from_pretrained(
pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
)

View File

@@ -982,19 +982,23 @@ class AutoTokenizer:
or tokenizer_class_from_name(config_tokenizer_class + "Fast") is not None
)
)
trust_remote_code = resolve_trust_remote_code(
trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code
)
if has_remote_code and trust_remote_code:
if has_remote_code:
if use_fast and tokenizer_auto_map[1] is not None:
class_ref = tokenizer_auto_map[1]
else:
class_ref = tokenizer_auto_map[0]
if "--" in class_ref:
upstream_repo = class_ref.split("--")[0]
else:
upstream_repo = None
trust_remote_code = resolve_trust_remote_code(
trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code, upstream_repo
)
if has_remote_code and trust_remote_code:
tokenizer_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs)
_ = kwargs.pop("code_revision", None)
if os.path.isdir(pretrained_model_name_or_path):
tokenizer_class.register_for_auto_class()
tokenizer_class.register_for_auto_class()
return tokenizer_class.from_pretrained(
pretrained_model_name_or_path, *inputs, trust_remote_code=trust_remote_code, **kwargs
)
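Tokenizers follow the same pattern with a `[slow, fast]` pair; picking `class_ref` before the trust prompt is what allows `upstream_repo` to be computed from the reference that will actually be loaded. A sketch (values hypothetical):

```python
# [slow tokenizer reference, fast tokenizer reference or None]
tokenizer_auto_map = ["some-org/upstream--tokenization.NewTokenizer", None]

use_fast = False
class_ref = (
    tokenizer_auto_map[1]
    if use_fast and tokenizer_auto_map[1] is not None
    else tokenizer_auto_map[0]
)
upstream_repo = class_ref.split("--")[0] if "--" in class_ref else None
print(class_ref, upstream_repo)
# some-org/upstream--tokenization.NewTokenizer some-org/upstream
```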

View File

@@ -339,8 +339,7 @@ class AutoVideoProcessor:
class_ref = video_processor_auto_map
video_processor_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs)
_ = kwargs.pop("code_revision", None)
if os.path.isdir(pretrained_model_name_or_path):
video_processor_class.register_for_auto_class()
video_processor_class.register_for_auto_class()
return video_processor_class.from_dict(config_dict, **kwargs)
elif video_processor_class is not None:
return video_processor_class.from_dict(config_dict, **kwargs)

View File

@@ -54,8 +54,6 @@ from .utils import (
PROCESSOR_NAME,
PushToHubMixin,
TensorType,
add_model_info_to_auto_map,
add_model_info_to_custom_pipelines,
cached_file,
copy_func,
direct_transformers_import,
@@ -938,16 +936,6 @@ class ProcessorMixin(PushToHubMixin):
if "chat_template" in kwargs:
processor_dict["chat_template"] = kwargs.pop("chat_template")
if not is_local:
if "auto_map" in processor_dict:
processor_dict["auto_map"] = add_model_info_to_auto_map(
processor_dict["auto_map"], pretrained_model_name_or_path
)
if "custom_pipelines" in processor_dict:
processor_dict["custom_pipelines"] = add_model_info_to_custom_pipelines(
processor_dict["custom_pipelines"], pretrained_model_name_or_path
)
return processor_dict, kwargs
@classmethod
@@ -1192,11 +1180,7 @@ class ProcessorMixin(PushToHubMixin):
Register this class with a given auto class. This should only be used for custom processors as the ones
in the library are already mapped with `AutoProcessor`.
<Tip warning={true}>
This API is experimental and may have some slight breaking changes in the next releases.
</Tip>
Args:
auto_class (`str` or `type`, *optional*, defaults to `"AutoProcessor"`):

View File

@@ -43,8 +43,6 @@ from .utils import (
PushToHubMixin,
TensorType,
add_end_docstrings,
add_model_info_to_auto_map,
add_model_info_to_custom_pipelines,
cached_file,
copy_func,
download_url,
@@ -2116,13 +2114,6 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
# For backward compatibility with old format.
if isinstance(init_kwargs["auto_map"], (tuple, list)):
init_kwargs["auto_map"] = {"AutoTokenizer": init_kwargs["auto_map"]}
init_kwargs["auto_map"] = add_model_info_to_auto_map(
init_kwargs["auto_map"], pretrained_model_name_or_path
)
if "custom_pipelines" in init_kwargs:
init_kwargs["custom_pipelines"] = add_model_info_to_custom_pipelines(
init_kwargs["custom_pipelines"], pretrained_model_name_or_path
)
if config_tokenizer_class is None:
# Matt: This entire block is only used to decide if the tokenizer class matches the class in the repo.
@@ -3973,11 +3964,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
Register this class with a given auto class. This should only be used for custom tokenizers as the ones in the
library are already mapped with `AutoTokenizer`.
<Tip warning={true}>
This API is experimental and may have some slight breaking changes in the next releases.
</Tip>
Args:
auto_class (`str` or `type`, *optional*, defaults to `"AutoTokenizer"`):
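One shim in the first hunk survives the cleanup: very old tokenizer configs stored `auto_map` as a bare pair rather than a dict keyed by auto class, and it is still normalized on load. A self-contained sketch of that conversion (values hypothetical):

```python
# Legacy format: a bare (slow, fast) pair...
init_kwargs = {"auto_map": ["tokenization.NewTokenizer", None]}

# ...normalized into the dict format keyed by the auto class.
if isinstance(init_kwargs["auto_map"], (tuple, list)):
    init_kwargs["auto_map"] = {"AutoTokenizer": init_kwargs["auto_map"]}
print(init_kwargs["auto_map"])
# {'AutoTokenizer': ['tokenization.NewTokenizer', None]}
```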

View File

@@ -50,8 +50,6 @@ from .generic import (
ModelOutput,
PaddingStrategy,
TensorType,
add_model_info_to_auto_map,
add_model_info_to_custom_pipelines,
cached_property,
can_return_loss,
can_return_tuple,

View File

@@ -727,32 +727,6 @@ def tensor_size(array):
raise ValueError(f"Type not supported for tensor_size: {type(array)}.")
def add_model_info_to_auto_map(auto_map, repo_id):
"""
Adds the information of the repo_id to a given auto map.
"""
for key, value in auto_map.items():
if isinstance(value, (tuple, list)):
auto_map[key] = [f"{repo_id}--{v}" if (v is not None and "--" not in v) else v for v in value]
elif value is not None and "--" not in value:
auto_map[key] = f"{repo_id}--{value}"
return auto_map
def add_model_info_to_custom_pipelines(custom_pipeline, repo_id):
"""
Adds the information of the repo_id to a given custom pipeline.
"""
# {custom_pipelines : {task: {"impl": "path.to.task"},...} }
for task in custom_pipeline.keys():
if "impl" in custom_pipeline[task]:
module = custom_pipeline[task]["impl"]
if "--" not in module:
custom_pipeline[task]["impl"] = f"{repo_id}--{module}"
return custom_pipeline
def infer_framework(model_class):
"""
Infers the framework of a given model without using isinstance(), because we cannot guarantee that the relevant
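Since the two helpers are deleted outright, here is what `add_model_info_to_auto_map` actually did, runnable as-is with the body copied from the removed code (repo id and references hypothetical). Stamping the repo id like this is exactly what left locally saved artifacts pointing back at the hub:

```python
def add_model_info_to_auto_map(auto_map, repo_id):
    """Adds the information of the repo_id to a given auto map."""
    for key, value in auto_map.items():
        if isinstance(value, (tuple, list)):
            auto_map[key] = [f"{repo_id}--{v}" if (v is not None and "--" not in v) else v for v in value]
        elif value is not None and "--" not in value:
            auto_map[key] = f"{repo_id}--{value}"
    return auto_map

print(add_model_info_to_auto_map(
    {"AutoTokenizer": ["tokenization.NewTokenizer", None]},
    "some-user/my-model",
))
# {'AutoTokenizer': ['some-user/my-model--tokenization.NewTokenizer', None]}
```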

View File

@@ -36,8 +36,6 @@ from .processing_utils import Unpack, VideosKwargs
from .utils import (
VIDEO_PROCESSOR_NAME,
TensorType,
add_model_info_to_auto_map,
add_model_info_to_custom_pipelines,
add_start_docstrings,
cached_file,
copy_func,
@@ -629,16 +627,6 @@ class BaseVideoProcessor(BaseImageProcessorFast):
logger.info(
f"loading configuration file {video_processor_file} from cache at {resolved_video_processor_file}"
)
if not is_local:
if "auto_map" in video_processor_dict:
video_processor_dict["auto_map"] = add_model_info_to_auto_map(
video_processor_dict["auto_map"], pretrained_model_name_or_path
)
if "custom_pipelines" in video_processor_dict:
video_processor_dict["custom_pipelines"] = add_model_info_to_custom_pipelines(
video_processor_dict["custom_pipelines"], pretrained_model_name_or_path
)
return video_processor_dict, kwargs
@classmethod

View File

@@ -122,19 +122,11 @@ class AutoConfigTest(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmp_dir:
config.save_pretrained(tmp_dir)
reloaded_config = AutoConfig.from_pretrained(tmp_dir, trust_remote_code=True)
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "configuration.py"))) # Assert we saved config code
# Assert we're pointing at local code and not another remote repo
self.assertEqual(reloaded_config.auto_map["AutoConfig"], "configuration.NewModelConfig")
self.assertEqual(reloaded_config.__class__.__name__, "NewModelConfig")
# The configuration file is cached in the snapshot directory. So the module file is not changed after dumping
# to a temp dir. Because the revision of the configuration file is not changed.
# Test the dynamic module is loaded only once if the configuration file is not changed.
self.assertIs(config.__class__, reloaded_config.__class__)
# Test the dynamic module is reloaded if we force it.
reloaded_config = AutoConfig.from_pretrained(
"hf-internal-testing/test_dynamic_model", trust_remote_code=True, force_download=True
)
self.assertIsNot(config.__class__, reloaded_config.__class__)
def test_from_pretrained_dynamic_config_conflict(self):
class NewModelConfigLocal(BertConfig):
model_type = "new-model"
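The updated assertions translate directly into the user-facing behavior this PR fixes; as a standalone sketch (requires network access to the test repo above):

```python
import os
import tempfile

from transformers import AutoConfig

config = AutoConfig.from_pretrained(
    "hf-internal-testing/test_dynamic_model", trust_remote_code=True
)
with tempfile.TemporaryDirectory() as tmp_dir:
    config.save_pretrained(tmp_dir)
    # The defining module is now written next to config.json...
    assert os.path.exists(os.path.join(tmp_dir, "configuration.py"))
    reloaded = AutoConfig.from_pretrained(tmp_dir, trust_remote_code=True)
    # ...and the reloaded auto_map points at the local file, not the hub repo.
    assert reloaded.auto_map["AutoConfig"] == "configuration.NewModelConfig"
```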

View File

@@ -13,6 +13,7 @@
# limitations under the License.
import json
import os
import sys
import tempfile
import unittest
@@ -125,19 +126,12 @@ class AutoFeatureExtractorTest(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmp_dir:
feature_extractor.save_pretrained(tmp_dir)
reloaded_feature_extractor = AutoFeatureExtractor.from_pretrained(tmp_dir, trust_remote_code=True)
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "feature_extractor.py"))) # Assert we saved code
self.assertEqual(
reloaded_feature_extractor.auto_map["AutoFeatureExtractor"], "feature_extractor.NewFeatureExtractor"
)
self.assertEqual(reloaded_feature_extractor.__class__.__name__, "NewFeatureExtractor")
# The feature extractor file is cached in the snapshot directory. So the module file is not changed after dumping
# to a temp dir. Because the revision of the module file is not changed.
# Test the dynamic module is loaded only once if the module file is not changed.
self.assertIs(feature_extractor.__class__, reloaded_feature_extractor.__class__)
# Test the dynamic module is reloaded if we force it.
reloaded_feature_extractor = AutoFeatureExtractor.from_pretrained(
"hf-internal-testing/test_dynamic_feature_extractor", trust_remote_code=True, force_download=True
)
self.assertIsNot(feature_extractor.__class__, reloaded_feature_extractor.__class__)
def test_new_feature_extractor_registration(self):
try:
AutoConfig.register("custom", CustomConfig)

View File

@@ -13,6 +13,7 @@
# limitations under the License.
import json
import os
import sys
import tempfile
import unittest
@@ -190,13 +191,12 @@ class AutoImageProcessorTest(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmp_dir:
image_processor.save_pretrained(tmp_dir)
reloaded_image_processor = AutoImageProcessor.from_pretrained(tmp_dir, trust_remote_code=True)
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "image_processor.py"))) # Assert we saved custom code
self.assertEqual(
reloaded_image_processor.auto_map["AutoImageProcessor"], "image_processor.NewImageProcessor"
)
self.assertEqual(reloaded_image_processor.__class__.__name__, "NewImageProcessor")
# The image processor file is cached in the snapshot directory. So the module file is not changed after dumping
# to a temp dir. Because the revision of the module file is not changed.
# Test the dynamic module is loaded only once if the module file is not changed.
self.assertIs(image_processor.__class__, reloaded_image_processor.__class__)
# Test the dynamic module is reloaded if we force it.
reloaded_image_processor = AutoImageProcessor.from_pretrained(
"hf-internal-testing/test_dynamic_image_processor", trust_remote_code=True, force_download=True

View File

@@ -332,11 +332,6 @@ class AutoModelTest(unittest.TestCase):
for p1, p2 in zip(model.parameters(), reloaded_model.parameters()):
self.assertTrue(torch.equal(p1, p2))
# The model file is cached in the snapshot directory. So the module file is not changed after dumping
# to a temp dir. Because the revision of the module file is not changed.
# Test the dynamic module is loaded only once if the module file is not changed.
self.assertIs(model.__class__, reloaded_model.__class__)
# Test the dynamic module is reloaded if we force it.
reloaded_model = AutoModel.from_pretrained(
"hf-internal-testing/test_dynamic_model", trust_remote_code=True, force_download=True
@@ -362,11 +357,6 @@ class AutoModelTest(unittest.TestCase):
for p1, p2 in zip(model.parameters(), reloaded_model.parameters()):
self.assertTrue(torch.equal(p1, p2))
# The model file is cached in the snapshot directory. So the module file is not changed after dumping
# to a temp dir. Because the revision of the module file is not changed.
# Test the dynamic module is loaded only once if the module file is not changed.
self.assertIs(model.__class__, reloaded_model.__class__)
# Test the dynamic module is reloaded if we force it.
reloaded_model = AutoModel.from_pretrained(
"hf-internal-testing/test_dynamic_model_with_util", trust_remote_code=True, force_download=True

View File

@@ -342,17 +342,20 @@ class AutoTokenizerTest(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmp_dir:
tokenizer.save_pretrained(tmp_dir)
reloaded_tokenizer = AutoTokenizer.from_pretrained(tmp_dir, trust_remote_code=True, use_fast=False)
self.assertTrue(
os.path.exists(os.path.join(tmp_dir, "tokenization.py"))
) # Assert we saved tokenizer code
self.assertEqual(reloaded_tokenizer._auto_class, "AutoTokenizer")
with open(os.path.join(tmp_dir, "tokenizer_config.json"), "r") as f:
tokenizer_config = json.load(f)
# Assert we're pointing at local code and not another remote repo
self.assertEqual(tokenizer_config["auto_map"]["AutoTokenizer"], ["tokenization.NewTokenizer", None])
self.assertEqual(reloaded_tokenizer.__class__.__name__, "NewTokenizer")
self.assertTrue(reloaded_tokenizer.special_attribute_present)
else:
self.assertEqual(tokenizer.__class__.__name__, "NewTokenizer")
self.assertEqual(reloaded_tokenizer.__class__.__name__, "NewTokenizer")
# The tokenizer file is cached in the snapshot directory. So the module file is not changed after dumping
# to a temp dir. Because the revision of the module file is not changed.
# Test the dynamic module is loaded only once if the module file is not changed.
self.assertIs(tokenizer.__class__, reloaded_tokenizer.__class__)
# Test the dynamic module is reloaded if we force it.
reloaded_tokenizer = AutoTokenizer.from_pretrained(
"hf-internal-testing/test_dynamic_tokenizer", trust_remote_code=True, force_download=True

View File

@@ -174,17 +174,6 @@ class AutoVideoProcessorTest(unittest.TestCase):
reloaded_video_processor = AutoVideoProcessor.from_pretrained(tmp_dir, trust_remote_code=True)
self.assertEqual(reloaded_video_processor.__class__.__name__, "NewVideoProcessor")
# The image processor file is cached in the snapshot directory. So the module file is not changed after dumping
# to a temp dir. Because the revision of the module file is not changed.
# Test the dynamic module is loaded only once if the module file is not changed.
self.assertIs(video_processor.__class__, reloaded_video_processor.__class__)
# Test the dynamic module is reloaded if we force it.
reloaded_video_processor = AutoVideoProcessor.from_pretrained(
"hf-internal-testing/test_dynamic_video_processor", trust_remote_code=True, force_download=True
)
self.assertIsNot(video_processor.__class__, reloaded_video_processor.__class__)
def test_new_video_processor_registration(self):
try:
AutoConfig.register("custom", CustomConfig)