From dfb00bf64415b15220bc0d4b7ab40318f195a2d4 Mon Sep 17 00:00:00 2001
From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Date: Mon, 8 Nov 2021 15:28:25 -0500
Subject: [PATCH] Expand dynamic supported objects to configs and tokenizers
 (#14296)

* Dynamic configs

* Add config test

* Better tests

* Add tokenizer and test

* Add to from_config

* With save
---
 src/transformers/models/auto/auto_factory.py  | 21 +++++-
 .../models/auto/configuration_auto.py         | 30 +++++++-
 .../models/auto/tokenization_auto.py          | 38 +++++++++-
 src/transformers/tokenization_utils_base.py   |  3 +
 tests/test_configuration_common.py            | 43 ++++++++++-
 tests/test_modeling_common.py                 | 74 ++++++++++++++++++-
 tests/test_tokenization_common.py             | 73 +++++++++++++++++-
 7 files changed, 272 insertions(+), 10 deletions(-)

diff --git a/src/transformers/models/auto/auto_factory.py b/src/transformers/models/auto/auto_factory.py
index 4178f1dfaed..34124fc272d 100644
--- a/src/transformers/models/auto/auto_factory.py
+++ b/src/transformers/models/auto/auto_factory.py
@@ -378,7 +378,24 @@ class _BaseAutoModelClass:
 
     @classmethod
     def from_config(cls, config, **kwargs):
-        if type(config) in cls._model_mapping.keys():
+        trust_remote_code = kwargs.pop("trust_remote_code", False)
+        if hasattr(config, "auto_map") and cls.__name__ in config.auto_map:
+            if not trust_remote_code:
+                raise ValueError(
+                    "Loading this model requires you to execute the modeling file in that repo "
+                    "on your local machine. Make sure you have read the code there to avoid malicious use, then set "
+                    "the option `trust_remote_code=True` to remove this error."
+                )
+            if kwargs.get("revision", None) is None:
+                logger.warn(
+                    "Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure "
+                    "no malicious code has been contributed in a newer revision."
+                )
+            class_ref = config.auto_map[cls.__name__]
+            module_file, class_name = class_ref.split(".")
+            model_class = get_class_from_dynamic_module(config.name_or_path, module_file + ".py", class_name, **kwargs)
+            return model_class._from_config(config, **kwargs)
+        elif type(config) in cls._model_mapping.keys():
             model_class = _get_model_class(config, cls._model_mapping)
             return model_class._from_config(config, **kwargs)
 
@@ -394,7 +411,7 @@ class _BaseAutoModelClass:
         kwargs["_from_auto"] = True
         if not isinstance(config, PretrainedConfig):
             config, kwargs = AutoConfig.from_pretrained(
-                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
+                pretrained_model_name_or_path, return_unused_kwargs=True, trust_remote_code=trust_remote_code, **kwargs
            )
         if hasattr(config, "auto_map") and cls.__name__ in config.auto_map:
             if not trust_remote_code:
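
The `from_config` branch above mirrors what `from_pretrained` already did for dynamic models: look the class up in `config.auto_map`, fetch the referenced module from the repo, and instantiate from it. A minimal sketch of the flow this enables (the repo name is a placeholder for any Hub repo whose config maps `AutoModel` to a modeling file it ships):

    from transformers import AutoConfig, AutoModel

    # config.json on the Hub carries e.g. "auto_map": {"AutoModel": "modeling.FakeModel"},
    # so building the model from the config alone also requires the explicit opt-in.
    config = AutoConfig.from_pretrained("user/test-dynamic-model", trust_remote_code=True)
    model = AutoModel.from_config(config, trust_remote_code=True)
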
diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py
index e2185927a1b..12bf0578c3d 100644
--- a/src/transformers/models/auto/configuration_auto.py
+++ b/src/transformers/models/auto/configuration_auto.py
@@ -21,8 +21,12 @@ from typing import List, Union
 
 from ...configuration_utils import PretrainedConfig
 from ...file_utils import CONFIG_NAME
+from ...utils import logging
+from .dynamic import get_class_from_dynamic_module
 
 
+logger = logging.get_logger(__name__)
+
 CONFIG_MAPPING_NAMES = OrderedDict(
     [
         # Add configs here
@@ -523,6 +527,10 @@ class AutoConfig:
                 If :obj:`True`, then this functions returns a :obj:`Tuple(config, unused_kwargs)` where `unused_kwargs` is
                 a dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e., the
                 part of ``kwargs`` which has not been used to update ``config`` and is otherwise ignored.
+            trust_remote_code (:obj:`bool`, `optional`, defaults to :obj:`False`):
+                Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
+                should only be set to :obj:`True` for repositories you trust and in which you have read the code, as it
+                will execute code present on the Hub on your local machine.
             kwargs(additional keyword arguments, `optional`):
                 The values in kwargs of any keys which are configuration attributes will be used to override the loaded
                 values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
@@ -555,8 +563,28 @@ class AutoConfig:
         {'foo': False}
         """
         kwargs["_from_auto"] = True
+        kwargs["name_or_path"] = pretrained_model_name_or_path
+        trust_remote_code = kwargs.pop("trust_remote_code", False)
         config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs)
-        if "model_type" in config_dict:
+        if "auto_map" in config_dict and "AutoConfig" in config_dict["auto_map"]:
+            if not trust_remote_code:
+                raise ValueError(
+                    f"Loading {pretrained_model_name_or_path} requires you to execute the configuration file in that repo "
+                    "on your local machine. Make sure you have read the code there to avoid malicious use, then set "
+                    "the option `trust_remote_code=True` to remove this error."
+                )
+            if kwargs.get("revision", None) is None:
+                logger.warn(
+                    "Explicitly passing a `revision` is encouraged when loading a configuration with custom code to "
+                    "ensure no malicious code has been contributed in a newer revision."
+                )
+            class_ref = config_dict["auto_map"]["AutoConfig"]
+            module_file, class_name = class_ref.split(".")
+            config_class = get_class_from_dynamic_module(
+                pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs
+            )
+            return config_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
+        elif "model_type" in config_dict:
             config_class = CONFIG_MAPPING[config_dict["model_type"]]
             return config_class.from_dict(config_dict, **kwargs)
         else:
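
With the hunk above, `AutoConfig.from_pretrained` checks for an `auto_map["AutoConfig"]` entry before falling back to the usual `model_type` lookup. A sketch of the resulting call, again with a placeholder repo name; passing an explicit `revision` also silences the warning added above:

    from transformers import AutoConfig

    config = AutoConfig.from_pretrained(
        "user/test-dynamic-config",
        trust_remote_code=True,
        revision="main",  # ideally a full commit sha rather than a branch name
    )
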
diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py
index 1084464eb68..04ac20ddf6a 100644
--- a/src/transformers/models/auto/tokenization_auto.py
+++ b/src/transformers/models/auto/tokenization_auto.py
@@ -41,6 +41,7 @@ from .configuration_auto import (
     model_type_to_module_name,
     replace_list_option_in_docstrings,
 )
+from .dynamic import get_class_from_dynamic_module
 
 
 logger = logging.get_logger(__name__)
@@ -412,6 +413,10 @@ class AutoTokenizer:
                 Whether or not to try to load the fast version of the tokenizer.
             tokenizer_type (:obj:`str`, `optional`):
                 Tokenizer type to be loaded.
+            trust_remote_code (:obj:`bool`, `optional`, defaults to :obj:`False`):
+                Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
+                should only be set to :obj:`True` for repositories you trust and in which you have read the code, as it
+                will execute code present on the Hub on your local machine.
             kwargs (additional keyword arguments, `optional`):
                 Will be passed to the Tokenizer ``__init__()`` method. Can be used to set special tokens like
                 ``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``,
@@ -436,6 +441,7 @@ class AutoTokenizer:
 
         use_fast = kwargs.pop("use_fast", True)
         tokenizer_type = kwargs.pop("tokenizer_type", None)
+        trust_remote_code = kwargs.pop("trust_remote_code", False)
 
         # First, let's see whether the tokenizer_type is passed so that we can leverage it
         if tokenizer_type is not None:
@@ -464,17 +470,45 @@ class AutoTokenizer:
         # Next, let's try to use the tokenizer_config file to get the tokenizer class.
         tokenizer_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs)
         config_tokenizer_class = tokenizer_config.get("tokenizer_class")
+        tokenizer_auto_map = tokenizer_config.get("auto_map")
 
         # If that did not work, let's try to use the config.
         if config_tokenizer_class is None:
             if not isinstance(config, PretrainedConfig):
-                config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+                config = AutoConfig.from_pretrained(
+                    pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
+                )
             config_tokenizer_class = config.tokenizer_class
+            if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map:
+                tokenizer_auto_map = config.auto_map["AutoTokenizer"]
 
         # If we have the tokenizer class from the tokenizer config or the model config we're good!
         if config_tokenizer_class is not None:
             tokenizer_class = None
-            if use_fast and not config_tokenizer_class.endswith("Fast"):
+            if tokenizer_auto_map is not None:
+                if not trust_remote_code:
+                    raise ValueError(
+                        f"Loading {pretrained_model_name_or_path} requires you to execute the tokenizer file in that repo "
+                        "on your local machine. Make sure you have read the code there to avoid malicious use, then set "
+                        "the option `trust_remote_code=True` to remove this error."
+                    )
+                if kwargs.get("revision", None) is None:
+                    logger.warn(
+                        "Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure "
+                        "no malicious code has been contributed in a newer revision."
+                    )
+
+                if use_fast and tokenizer_auto_map[1] is not None:
+                    class_ref = tokenizer_auto_map[1]
+                else:
+                    class_ref = tokenizer_auto_map[0]
+
+                module_file, class_name = class_ref.split(".")
+                tokenizer_class = get_class_from_dynamic_module(
+                    pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs
+                )
+
+            elif use_fast and not config_tokenizer_class.endswith("Fast"):
                 tokenizer_class_candidate = f"{config_tokenizer_class}Fast"
                 tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)
             if tokenizer_class is None:
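
Unlike the config case, a tokenizer's `auto_map` entry is a pair rather than a single reference, because a repo can ship both a slow and a fast implementation. Roughly, with a placeholder repo name:

    from transformers import AutoTokenizer

    # tokenizer_config.json:
    #   "auto_map": ["tokenizer.FakeTokenizer", "tokenizer.FakeTokenizerFast"]
    # use_fast=True (the default) picks the second entry when it is not null;
    # otherwise the slow class is loaded.
    fast = AutoTokenizer.from_pretrained("user/test-dynamic-tokenizer", trust_remote_code=True)
    slow = AutoTokenizer.from_pretrained("user/test-dynamic-tokenizer", use_fast=False, trust_remote_code=True)
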
diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py
index d35f53cedc8..41ac3f19209 100644
--- a/src/transformers/tokenization_utils_base.py
+++ b/src/transformers/tokenization_utils_base.py
@@ -1784,6 +1784,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
         # First attempt. We get tokenizer_class from tokenizer_config to check mismatch between tokenizers.
         config_tokenizer_class = init_kwargs.get("tokenizer_class")
         init_kwargs.pop("tokenizer_class", None)
+        init_kwargs.pop("auto_map", None)
         saved_init_inputs = init_kwargs.pop("init_inputs", ())
         if not init_inputs:
             init_inputs = saved_init_inputs
@@ -2028,6 +2029,8 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
         if tokenizer_class.endswith("Fast") and tokenizer_class != "PreTrainedTokenizerFast":
             tokenizer_class = tokenizer_class[:-4]
         tokenizer_config["tokenizer_class"] = tokenizer_class
+        if getattr(self, "_auto_map", None) is not None:
+            tokenizer_config["auto_map"] = self._auto_map
 
         with open(tokenizer_config_file, "w", encoding="utf-8") as f:
             f.write(json.dumps(tokenizer_config, ensure_ascii=False))
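
These two base-class tweaks close the save/load loop: an `_auto_map` set on a tokenizer instance is written into tokenizer_config.json on save, and the key is popped from the init kwargs on load so it never reaches the tokenizer's `__init__`. The tests below exercise exactly this round trip, along the lines of:

    # Sketch of the round trip (FakeTokenizer comes from the tests below):
    tokenizer._auto_map = ("tokenizer.FakeTokenizer", "tokenizer.FakeTokenizerFast")
    tokenizer.save_pretrained(tmp_dir)  # writes "auto_map" into tokenizer_config.json
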
diff --git a/tests/test_configuration_common.py b/tests/test_configuration_common.py
index da675e45ef4..66c7652f399 100644
--- a/tests/test_configuration_common.py
+++ b/tests/test_configuration_common.py
@@ -19,9 +19,9 @@ import os
 import tempfile
 import unittest
 
-from huggingface_hub import delete_repo, login
+from huggingface_hub import Repository, delete_repo, login
 from requests.exceptions import HTTPError
-from transformers import BertConfig, GPT2Config, is_torch_available
+from transformers import AutoConfig, BertConfig, GPT2Config, is_torch_available
 from transformers.configuration_utils import PretrainedConfig
 from transformers.testing_utils import PASS, USER, is_staging_test
@@ -190,6 +190,23 @@
         self.check_config_arguments_init()
 
 
+class FakeConfig(PretrainedConfig):
+    def __init__(self, attribute=1, **kwargs):
+        self.attribute = attribute
+        super().__init__(**kwargs)
+
+
+# Make sure this is synchronized with the config above.
+FAKE_CONFIG_CODE = """
+from transformers import PretrainedConfig
+
+class FakeConfig(PretrainedConfig):
+    def __init__(self, attribute=1, **kwargs):
+        self.attribute = attribute
+        super().__init__(**kwargs)
+"""
+
+
 @is_staging_test
 class ConfigPushToHubTester(unittest.TestCase):
     @classmethod
@@ -208,6 +225,11 @@ class ConfigPushToHubTester(unittest.TestCase):
         except HTTPError:
             pass
 
+        try:
+            delete_repo(token=cls._token, name="test-dynamic-config")
+        except HTTPError:
+            pass
+
     def test_push_to_hub(self):
         config = BertConfig(
             vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
@@ -238,6 +260,23 @@ class ConfigPushToHubTester(unittest.TestCase):
             if k != "transformers_version":
                 self.assertEqual(v, getattr(new_config, k))
 
+    def test_push_to_hub_dynamic_config(self):
+        config = FakeConfig(attribute=42)
+        config.auto_map = {"AutoConfig": "configuration.FakeConfig"}
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            repo = Repository(tmp_dir, clone_from=f"{USER}/test-dynamic-config", use_auth_token=self._token)
+            config.save_pretrained(tmp_dir)
+            with open(os.path.join(tmp_dir, "configuration.py"), "w") as f:
+                f.write(FAKE_CONFIG_CODE)
+
+            repo.push_to_hub()
+
+        new_config = AutoConfig.from_pretrained(f"{USER}/test-dynamic-config", trust_remote_code=True)
+        # Can't make an isinstance check because the new_config is from the FakeConfig class of a dynamic module
+        self.assertEqual(new_config.__class__.__name__, "FakeConfig")
+        self.assertEqual(new_config.attribute, 42)
+
 
 class ConfigTestUtils(unittest.TestCase):
     def test_config_from_string(self):
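
For reference, the repo layout this test pushes, with both files produced by the test body above:

    # test-dynamic-config/
    #   config.json       <- includes "auto_map": {"AutoConfig": "configuration.FakeConfig"}
    #   configuration.py  <- FAKE_CONFIG_CODE, defining FakeConfig
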
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 9dfe9275fda..6da867e2a7b 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -30,7 +30,14 @@ import numpy as np
 import transformers
 from huggingface_hub import Repository, delete_repo, login
 from requests.exceptions import HTTPError
-from transformers import AutoModel, AutoModelForSequenceClassification, is_torch_available, logging
+from transformers import (
+    AutoConfig,
+    AutoModel,
+    AutoModelForSequenceClassification,
+    PretrainedConfig,
+    is_torch_available,
+    logging,
+)
 from transformers.file_utils import WEIGHTS_NAME, is_flax_available, is_torch_fx_available
 from transformers.models.auto import get_values
 from transformers.testing_utils import (
@@ -67,7 +74,6 @@ if is_torch_available():
         AdaptiveEmbedding,
         BertConfig,
         BertModel,
-        PretrainedConfig,
         PreTrainedModel,
         T5Config,
         T5ForConditionalGeneration,
@@ -2078,6 +2084,23 @@ class ModelUtilsTest(TestCasePlus):
         self.assertEqual(model.dtype, torch.float16)
 
 
+class FakeConfig(PretrainedConfig):
+    def __init__(self, attribute=1, **kwargs):
+        self.attribute = attribute
+        super().__init__(**kwargs)
+
+
+# Make sure this is synchronized with the config above.
+FAKE_CONFIG_CODE = """
+from transformers import PretrainedConfig
+
+class FakeConfig(PretrainedConfig):
+    def __init__(self, attribute=1, **kwargs):
+        self.attribute = attribute
+        super().__init__(**kwargs)
+"""
+
+
 if is_torch_available():
 
     class FakeModel(PreTrainedModel):
@@ -2140,6 +2163,11 @@ class ModelPushToHubTester(unittest.TestCase):
         except HTTPError:
             pass
 
+        try:
+            delete_repo(token=cls._token, name="test-dynamic-model-config")
+        except HTTPError:
+            pass
+
     def test_push_to_hub(self):
         config = BertConfig(
             vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
@@ -2185,5 +2213,47 @@ class ModelPushToHubTester(unittest.TestCase):
             repo.push_to_hub()
 
         new_model = AutoModel.from_pretrained(f"{USER}/test-dynamic-model", trust_remote_code=True)
+        # Can't make an isinstance check because the new_model is from the FakeModel class of a dynamic module
+        self.assertEqual(new_model.__class__.__name__, "FakeModel")
         for p1, p2 in zip(model.parameters(), new_model.parameters()):
             self.assertTrue(torch.equal(p1, p2))
+
+        config = AutoConfig.from_pretrained(f"{USER}/test-dynamic-model")
+        new_model = AutoModel.from_config(config, trust_remote_code=True)
+        self.assertEqual(new_model.__class__.__name__, "FakeModel")
+
+    def test_push_to_hub_dynamic_model_and_config(self):
+        config = FakeConfig(
+            attribute=42,
+            vocab_size=99,
+            hidden_size=32,
+            num_hidden_layers=5,
+            num_attention_heads=4,
+            intermediate_size=37,
+        )
+        config.auto_map = {"AutoConfig": "configuration.FakeConfig", "AutoModel": "modeling.FakeModel"}
+        model = FakeModel(config)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            repo = Repository(tmp_dir, clone_from=f"{USER}/test-dynamic-model-config", use_auth_token=self._token)
+            model.save_pretrained(tmp_dir)
+            with open(os.path.join(tmp_dir, "configuration.py"), "w") as f:
+                f.write(FAKE_CONFIG_CODE)
+            with open(os.path.join(tmp_dir, "modeling.py"), "w") as f:
+                f.write(FAKE_MODEL_CODE)
+
+            repo.push_to_hub()
+
+        new_model = AutoModel.from_pretrained(f"{USER}/test-dynamic-model-config", trust_remote_code=True)
+        # Can't make an isinstance check because the new_model.config is from the FakeConfig class of a dynamic module
+        self.assertEqual(new_model.config.__class__.__name__, "FakeConfig")
+        self.assertEqual(new_model.config.attribute, 42)
+
+        # Can't make an isinstance check because the new_model is from the FakeModel class of a dynamic module
+        self.assertEqual(new_model.__class__.__name__, "FakeModel")
+        for p1, p2 in zip(model.parameters(), new_model.parameters()):
+            self.assertTrue(torch.equal(p1, p2))
+
+        config = AutoConfig.from_pretrained(f"{USER}/test-dynamic-model")
+        new_model = AutoModel.from_config(config, trust_remote_code=True)
+        self.assertEqual(new_model.__class__.__name__, "FakeModel")
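
The combined test pushes a single repo carrying both dynamic files, roughly:

    # test-dynamic-model-config/
    #   config.json       <- "auto_map": {"AutoConfig": "configuration.FakeConfig",
    #                                     "AutoModel": "modeling.FakeModel"}
    #   configuration.py  <- FakeConfig (FAKE_CONFIG_CODE)
    #   modeling.py       <- FakeModel (FAKE_MODEL_CODE)
    #   pytorch_model.bin
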
diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py
index d733d2d4b0a..685ccf43af1 100644
--- a/tests/test_tokenization_common.py
+++ b/tests/test_tokenization_common.py
@@ -27,11 +27,12 @@ from collections import OrderedDict
 from itertools import takewhile
 from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union
 
-from huggingface_hub import delete_repo, login
+from huggingface_hub import Repository, delete_repo, login
 from requests.exceptions import HTTPError
 from transformers import (
     AlbertTokenizer,
     AlbertTokenizerFast,
+    AutoTokenizer,
     BertTokenizer,
     BertTokenizerFast,
     PreTrainedTokenizer,
@@ -41,6 +42,7 @@ from transformers import (
     Trainer,
     TrainingArguments,
     is_tf_available,
+    is_tokenizers_available,
     is_torch_available,
 )
 from transformers.testing_utils import (
@@ -3513,6 +3515,28 @@ class TokenizerTesterMixin:
                 self.assertIn("tokenizer.json", os.listdir(os.path.join(tmp_dir, "checkpoint")))
 
 
+class FakeTokenizer(BertTokenizer):
+    pass
+
+
+if is_tokenizers_available():
+
+    class FakeTokenizerFast(BertTokenizerFast):
+        pass
+
+
+# Make sure this is synchronized with the tokenizers above.
+FAKE_TOKENIZER_CODE = """
+from transformers import BertTokenizer, BertTokenizerFast
+
+class FakeTokenizer(BertTokenizer):
+    pass
+
+class FakeTokenizerFast(BertTokenizerFast):
+    pass
+"""
+
+
 @is_staging_test
 class TokenizerPushToHubTester(unittest.TestCase):
     vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "bla", "blou"]
@@ -3533,6 +3557,11 @@ class TokenizerPushToHubTester(unittest.TestCase):
         except HTTPError:
             pass
 
+        try:
+            delete_repo(token=cls._token, name="test-dynamic-tokenizer")
+        except HTTPError:
+            pass
+
     def test_push_to_hub(self):
         with tempfile.TemporaryDirectory() as tmp_dir:
             vocab_file = os.path.join(tmp_dir, "vocab.txt")
@@ -3562,6 +3591,48 @@ class TokenizerPushToHubTester(unittest.TestCase):
             new_tokenizer = BertTokenizer.from_pretrained("valid_org/test-tokenizer-org")
             self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
 
+    def test_push_to_hub_dynamic_tokenizer(self):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            vocab_file = os.path.join(tmp_dir, "vocab.txt")
+            with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
+                vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
+            tokenizer = FakeTokenizer(vocab_file)
+
+        # No fast custom tokenizer
+        tokenizer._auto_map = ("tokenizer.FakeTokenizer", None)
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            repo = Repository(tmp_dir, clone_from=f"{USER}/test-dynamic-tokenizer", use_auth_token=self._token)
+            print(os.listdir((tmp_dir)))
+            tokenizer.save_pretrained(tmp_dir)
+            with open(os.path.join(tmp_dir, "tokenizer.py"), "w") as f:
+                f.write(FAKE_TOKENIZER_CODE)
+
+            repo.push_to_hub()
+
+        tokenizer = AutoTokenizer.from_pretrained(f"{USER}/test-dynamic-tokenizer", trust_remote_code=True)
+        # Can't make an isinstance check because the new_model.config is from the FakeConfig class of a dynamic module
+        self.assertEqual(tokenizer.__class__.__name__, "FakeTokenizer")
+
+        # Fast and slow custom tokenizer
+        tokenizer._auto_map = ("tokenizer.FakeTokenizer", "tokenizer.FakeTokenizerFast")
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            repo = Repository(tmp_dir, clone_from=f"{USER}/test-dynamic-tokenizer", use_auth_token=self._token)
+            print(os.listdir((tmp_dir)))
+            tokenizer.save_pretrained(tmp_dir)
+            with open(os.path.join(tmp_dir, "tokenizer.py"), "w") as f:
+                f.write(FAKE_TOKENIZER_CODE)
+
+            repo.push_to_hub()
+
+        tokenizer = AutoTokenizer.from_pretrained(f"{USER}/test-dynamic-tokenizer", trust_remote_code=True)
+        # Can't make an isinstance check because the new_model.config is from the FakeConfig class of a dynamic module
+        self.assertEqual(tokenizer.__class__.__name__, "FakeTokenizerFast")
+        tokenizer = AutoTokenizer.from_pretrained(
+            f"{USER}/test-dynamic-tokenizer", use_fast=False, trust_remote_code=True
+        )
+        # Can't make an isinstance check because the new_model.config is from the FakeConfig class of a dynamic module
+        self.assertEqual(tokenizer.__class__.__name__, "FakeTokenizer")
+
 
 class TrieTest(unittest.TestCase):
     def test_trie(self):