From 45f56580a7e11b5b894374f8e1c7bdd54d982682 Mon Sep 17 00:00:00 2001
From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Date: Tue, 15 Feb 2022 09:44:35 -0500
Subject: [PATCH] Allow custom code for Processors (#15649)

* Allow custom code for Processors

* Add more test

* Test all auto_map configs are properly set
---
 src/transformers/dynamic_module_utils.py   |  23 ++--
 .../models/auto/processing_auto.py          |  89 ++++++++++-----
 .../models/auto/tokenization_auto.py        |   8 +-
 src/transformers/processing_utils.py        |  48 +++++++-
 tests/test_feature_extraction_common.py     |   4 +-
 tests/test_processor_auto.py                | 106 +++++++++++++++++-
 tests/test_tokenization_auto.py             |  32 ++++++
 tests/test_tokenization_common.py           |  13 ++-
 utils/test_module/custom_processing.py      |   6 +
 9 files changed, 288 insertions(+), 41 deletions(-)
 create mode 100644 utils/test_module/custom_processing.py

diff --git a/src/transformers/dynamic_module_utils.py b/src/transformers/dynamic_module_utils.py
index 7ce71ac75f9..91f5bb36a96 100644
--- a/src/transformers/dynamic_module_utils.py
+++ b/src/transformers/dynamic_module_utils.py
@@ -395,8 +395,8 @@ def custom_object_save(obj, folder, config=None):
             "this code in a separate module so we can include it in the saved folder and make it easier to share via "
             "the Hub."
         )
-    # Add object class to the config auto_map
-    if config is not None:
+
+    def _set_auto_map_in_config(_config):
         module_name = obj.__class__.__module__
         last_module = module_name.split(".")[-1]
         full_name = f"{last_module}.{obj.__class__.__name__}"
@@ -418,12 +418,21 @@ def custom_object_save(obj, folder, config=None):
 
             full_name = (slow_tokenizer_class, fast_tokenizer_class)
 
-        if isinstance(config, dict):
-            config["auto_map"] = full_name
-        elif getattr(config, "auto_map", None) is not None:
-            config.auto_map[obj._auto_class] = full_name
+        if isinstance(_config, dict):
+            auto_map = _config.get("auto_map", {})
+            auto_map[obj._auto_class] = full_name
+            _config["auto_map"] = auto_map
+        elif getattr(_config, "auto_map", None) is not None:
+            _config.auto_map[obj._auto_class] = full_name
         else:
-            config.auto_map = {obj._auto_class: full_name}
+            _config.auto_map = {obj._auto_class: full_name}
+
+    # Add object class to the config auto_map
+    if isinstance(config, (list, tuple)):
+        for cfg in config:
+            _set_auto_map_in_config(cfg)
+    elif config is not None:
+        _set_auto_map_in_config(config)
 
     # Copy module file to the output folder.
     object_file = sys.modules[obj.__module__].__file__
diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py
index 5a788e16b8a..7b1365a3e3c 100644
--- a/src/transformers/models/auto/processing_auto.py
+++ b/src/transformers/models/auto/processing_auto.py
@@ -20,19 +20,22 @@ from collections import OrderedDict
 
 # Build the list of all feature extractors
 from ...configuration_utils import PretrainedConfig
+from ...dynamic_module_utils import get_class_from_dynamic_module
 from ...feature_extraction_utils import FeatureExtractionMixin
 from ...file_utils import CONFIG_NAME, FEATURE_EXTRACTOR_NAME, get_file_from_repo
 from ...tokenization_utils import TOKENIZER_CONFIG_FILE
+from ...utils import logging
 from .auto_factory import _LazyAutoMapping
 from .configuration_auto import (
     CONFIG_MAPPING_NAMES,
     AutoConfig,
-    config_class_to_model_type,
     model_type_to_module_name,
     replace_list_option_in_docstrings,
 )
 
 
+logger = logging.get_logger(__name__)
+
 PROCESSOR_MAPPING_NAMES = OrderedDict(
     [
         ("clip", "CLIPProcessor"),
@@ -120,6 +123,10 @@ class AutoProcessor:
                 functions returns a `Tuple(feature_extractor, unused_kwargs)` where *unused_kwargs* is a dictionary
                 consisting of the key/value pairs whose keys are not feature extractor attributes: i.e., the part of
                 `kwargs` which has not been used to update `feature_extractor` and is otherwise ignored.
+            trust_remote_code (`bool`, *optional*, defaults to `False`):
+                Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
+                should only be set to `True` for repositories you trust and in which you have read the code, as it will
+                execute code present on the Hub on your local machine.
             kwargs (`Dict[str, Any]`, *optional*):
                 The values in kwargs of any keys which are feature extractor attributes will be used to override the
                 loaded values. Behavior concerning key/value pairs whose keys are *not* feature extractor attributes is
@@ -143,10 +150,14 @@ class AutoProcessor:
         >>> processor = AutoProcessor.from_pretrained("./test/saved_model/")
         ```"""
         config = kwargs.pop("config", None)
+        trust_remote_code = kwargs.pop("trust_remote_code", False)
         kwargs["_from_auto"] = True
 
+        processor_class = None
+        processor_auto_map = None
+
         # First, let's see if we have a preprocessor config.
-        # Filter the kwargs for `get_file_from_repo``.
+        # Filter the kwargs for `get_file_from_repo`.
         get_file_from_repo_kwargs = {
             key: kwargs[key] for key in inspect.signature(get_file_from_repo).parameters.keys() if key in kwargs
         }
@@ -156,35 +167,63 @@ class AutoProcessor:
         )
         if preprocessor_config_file is not None:
             config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs)
-            if "processor_class" in config_dict:
-                processor_class = processor_class_from_name(config_dict["processor_class"])
-                return processor_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
+            processor_class = config_dict.get("processor_class", None)
+            if "AutoProcessor" in config_dict.get("auto_map", {}):
+                processor_auto_map = config_dict["auto_map"]["AutoProcessor"]
 
-        # Next, let's check whether the processor class is saved in a tokenizer
-        # Let's start by checking whether the processor class is saved in a feature extractor
-        tokenizer_config_file = get_file_from_repo(
-            pretrained_model_name_or_path, TOKENIZER_CONFIG_FILE, **get_file_from_repo_kwargs
-        )
-        if tokenizer_config_file is not None:
-            with open(tokenizer_config_file, encoding="utf-8") as reader:
-                config_dict = json.load(reader)
+        if processor_class is None:
+            # Next, let's check whether the processor class is saved in a tokenizer
+            tokenizer_config_file = get_file_from_repo(
+                pretrained_model_name_or_path, TOKENIZER_CONFIG_FILE, **get_file_from_repo_kwargs
+            )
+            if tokenizer_config_file is not None:
+                with open(tokenizer_config_file, encoding="utf-8") as reader:
+                    config_dict = json.load(reader)
 
-            if "processor_class" in config_dict:
-                processor_class = processor_class_from_name(config_dict["processor_class"])
-                return processor_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
+                processor_class = config_dict.get("processor_class", None)
+                if "AutoProcessor" in config_dict.get("auto_map", {}):
+                    processor_auto_map = config_dict["auto_map"]["AutoProcessor"]
 
-        # Otherwise, load config, if it can be loaded.
-        if not isinstance(config, PretrainedConfig):
-            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+        if processor_class is None:
+            # Otherwise, load config, if it can be loaded.
+            if not isinstance(config, PretrainedConfig):
+                config = AutoConfig.from_pretrained(
+                    pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
+                )
 
-        model_type = config_class_to_model_type(type(config).__name__)
+            # And check if the config contains the processor class.
+            processor_class = getattr(config, "processor_class", None)
+            if hasattr(config, "auto_map") and "AutoProcessor" in config.auto_map:
+                processor_auto_map = config.auto_map["AutoProcessor"]
 
-        if getattr(config, "processor_class", None) is not None:
-            processor_class = processor_class_from_name(config.processor_class)
-            return processor_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
+        if processor_class is not None:
+            # If we have custom code for a processor, we get the proper class.
+            if processor_auto_map is not None:
+                if not trust_remote_code:
+                    raise ValueError(
+                        f"Loading {pretrained_model_name_or_path} requires you to execute the processor file "
+                        "in that repo on your local machine. Make sure you have read the code there to avoid "
+                        "malicious use, then set the option `trust_remote_code=True` to remove this error."
+                    )
+                if kwargs.get("revision", None) is None:
+                    logger.warning(
+                        "Explicitly passing a `revision` is encouraged when loading a processor with custom "
+                        "code to ensure no malicious code has been contributed in a newer revision."
+                    )
 
-        model_type = config_class_to_model_type(type(config).__name__)
-        if model_type is not None:
+                module_file, class_name = processor_auto_map.split(".")
+                processor_class = get_class_from_dynamic_module(
+                    pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs
+                )
+            else:
+                processor_class = processor_class_from_name(processor_class)
+
+            return processor_class.from_pretrained(
+                pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
+            )
+
+        # Last try: we use the PROCESSOR_MAPPING.
+        if type(config) in PROCESSOR_MAPPING:
             return PROCESSOR_MAPPING[type(config)].from_pretrained(pretrained_model_name_or_path, **kwargs)
 
         raise ValueError(
diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py
index 043843fd52d..41d44c641f3 100644
--- a/src/transformers/models/auto/tokenization_auto.py
+++ b/src/transformers/models/auto/tokenization_auto.py
@@ -469,7 +469,13 @@ class AutoTokenizer:
         # Next, let's try to use the tokenizer_config file to get the tokenizer class.
         tokenizer_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs)
         config_tokenizer_class = tokenizer_config.get("tokenizer_class")
-        tokenizer_auto_map = tokenizer_config.get("auto_map")
+        tokenizer_auto_map = None
+        if "auto_map" in tokenizer_config:
+            if isinstance(tokenizer_config["auto_map"], (tuple, list)):
+                # Legacy format for dynamic tokenizers
+                tokenizer_auto_map = tokenizer_config["auto_map"]
+            else:
+                tokenizer_auto_map = tokenizer_config["auto_map"].get("AutoTokenizer", None)
 
         # If that did not work, let's try to use the config.
         if config_tokenizer_class is None:
diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py
index ec6196c862a..83c85b1a203 100644
--- a/src/transformers/processing_utils.py
+++ b/src/transformers/processing_utils.py
@@ -17,10 +17,14 @@
 """
 
 import importlib.util
+import os
 from pathlib import Path
 
+from .dynamic_module_utils import custom_object_save
+from .tokenization_utils_base import PreTrainedTokenizerBase
 
-# Comment to write
+
+# Dynamically import the Transformers module to grab the attribute classes of the processor from their names.
 spec = importlib.util.spec_from_file_location(
     "transformers", Path(__file__).parent / "__init__.py", submodule_search_locations=[Path(__file__).parent]
 )
@@ -42,6 +46,7 @@ class ProcessorMixin:
     # Names need to be attr_class for attr in attributes
     feature_extractor_class = None
    tokenizer_class = None
+    _auto_class = None
 
     # args have to match the attributes class attribute
     def __init__(self, *args, **kwargs):
@@ -101,6 +106,14 @@ class ProcessorMixin:
                 Directory where the feature extractor JSON file and the tokenizer files will be saved (directory will
                 be created if it does not exist).
         """
+        os.makedirs(save_directory, exist_ok=True)
+        # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
+        # loaded from the Hub.
+        if self._auto_class is not None:
+            attrs = [getattr(self, attribute_name) for attribute_name in self.attributes]
+            configs = [(a.init_kwargs if isinstance(a, PreTrainedTokenizerBase) else a) for a in attrs]
+            custom_object_save(self, save_directory, config=configs)
+
         for attribute_name in self.attributes:
             attribute = getattr(self, attribute_name)
             # Include the processor class in the attribute config so this processor can then be reloaded with the
@@ -108,6 +121,13 @@ class ProcessorMixin:
             attribute._set_processor_class(self.__class__.__name__)
             attribute.save_pretrained(save_directory)
 
+        if self._auto_class is not None:
+            # We added an attribute to the init_kwargs of the tokenizers, which needs to be cleaned up.
+            for attribute_name in self.attributes:
+                attribute = getattr(self, attribute_name)
+                if isinstance(attribute, PreTrainedTokenizerBase):
+                    del attribute.init_kwargs["auto_map"]
+
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
         r"""
@@ -142,6 +162,32 @@ class ProcessorMixin:
         args = cls._get_arguments_from_pretrained(pretrained_model_name_or_path, **kwargs)
         return cls(*args)
 
+    @classmethod
+    def register_for_auto_class(cls, auto_class="AutoProcessor"):
+        """
+        Register this class with a given auto class. This should only be used for custom processors as the ones
+        in the library are already mapped with `AutoProcessor`.
+
+        <Tip warning={true}>
+
+        This API is experimental and may have some slight breaking changes in the next releases.
+
+        </Tip>
+
+        Args:
+            auto_class (`str` or `type`, *optional*, defaults to `"AutoProcessor"`):
+                The auto class to register this new processor with.
+        """
+        if not isinstance(auto_class, str):
+            auto_class = auto_class.__name__
+
+        import transformers.models.auto as auto_module
+
+        if not hasattr(auto_module, auto_class):
+            raise ValueError(f"{auto_class} is not a valid auto class.")
+
+        cls._auto_class = auto_class
+
     @classmethod
     def _get_arguments_from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
         args = []
diff --git a/tests/test_feature_extraction_common.py b/tests/test_feature_extraction_common.py
index 762bc6ec5fb..098d982b149 100644
--- a/tests/test_feature_extraction_common.py
+++ b/tests/test_feature_extraction_common.py
@@ -40,8 +40,6 @@ if is_torch_available():
 if is_vision_available():
     from PIL import Image
 
-SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")
-
 
 SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")
 
@@ -119,7 +117,7 @@ class FeatureExtractionSavingTestMixin:
 
 
 @is_staging_test
-class ConfigPushToHubTester(unittest.TestCase):
+class FeatureExtractorPushToHubTester(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
         cls._token = login(username=USER, password=PASS)
diff --git a/tests/test_processor_auto.py b/tests/test_processor_auto.py
index 6f64480a98d..7cbb5b06a9d 100644
--- a/tests/test_processor_auto.py
+++ b/tests/test_processor_auto.py
@@ -15,20 +15,34 @@
 
 import json
 import os
+import sys
 import tempfile
 import unittest
+from pathlib import Path
 from shutil import copyfile
 
+from huggingface_hub import Repository, delete_repo, login
+from requests.exceptions import HTTPError
 from transformers import AutoProcessor, AutoTokenizer, Wav2Vec2Config, Wav2Vec2FeatureExtractor, Wav2Vec2Processor
-from transformers.file_utils import FEATURE_EXTRACTOR_NAME
+from transformers.file_utils import FEATURE_EXTRACTOR_NAME, is_tokenizers_available
+from transformers.testing_utils import PASS, USER, is_staging_test
 from transformers.tokenization_utils import TOKENIZER_CONFIG_FILE
 
 
+sys.path.append(str(Path(__file__).parent.parent / "utils"))
+
+from test_module.custom_feature_extraction import CustomFeatureExtractor  # noqa E402
+from test_module.custom_processing import CustomProcessor  # noqa E402
+from test_module.custom_tokenization import CustomTokenizer  # noqa E402
+
+
 SAMPLE_PROCESSOR_CONFIG = os.path.join(
     os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy_feature_extractor_config.json"
 )
 SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/vocab.json")
 
+SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")
+
 
 class AutoFeatureExtractorTest(unittest.TestCase):
     def test_processor_from_model_shortcut(self):
@@ -115,3 +129,93 @@ class AutoFeatureExtractorTest(unittest.TestCase):
             processor = AutoProcessor.from_pretrained(tmpdirname)
 
         self.assertIsInstance(processor, Wav2Vec2Processor)
+
+    def test_from_pretrained_dynamic_processor(self):
+        processor = AutoProcessor.from_pretrained("hf-internal-testing/test_dynamic_processor", trust_remote_code=True)
+        self.assertTrue(processor.special_attribute_present)
+        self.assertEqual(processor.__class__.__name__, "NewProcessor")
+
+        feature_extractor = processor.feature_extractor
+        self.assertTrue(feature_extractor.special_attribute_present)
+        self.assertEqual(feature_extractor.__class__.__name__, "NewFeatureExtractor")
+
+        tokenizer = processor.tokenizer
+        self.assertTrue(tokenizer.special_attribute_present)
+        if is_tokenizers_available():
+            self.assertEqual(tokenizer.__class__.__name__, "NewTokenizerFast")
+
+            # Test we can also load the slow version
+            processor = AutoProcessor.from_pretrained(
+                "hf-internal-testing/test_dynamic_processor", trust_remote_code=True, use_fast=False
+            )
+            tokenizer = processor.tokenizer
+            self.assertTrue(tokenizer.special_attribute_present)
+            self.assertEqual(tokenizer.__class__.__name__, "NewTokenizer")
+        else:
+            self.assertEqual(tokenizer.__class__.__name__, "NewTokenizer")
+
+
+@is_staging_test
+class ProcessorPushToHubTester(unittest.TestCase):
+    vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "bla", "blou"]
+
+    @classmethod
+    def setUpClass(cls):
+        cls._token = login(username=USER, password=PASS)
+
+    @classmethod
+    def tearDownClass(cls):
+        try:
+            delete_repo(token=cls._token, name="test-dynamic-processor")
+        except HTTPError:
+            pass
+
+    def test_push_to_hub_dynamic_processor(self):
+        CustomFeatureExtractor.register_for_auto_class()
+        CustomTokenizer.register_for_auto_class()
+        CustomProcessor.register_for_auto_class()
+
+        feature_extractor = CustomFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            vocab_file = os.path.join(tmp_dir, "vocab.txt")
+            with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
+                vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
+            tokenizer = CustomTokenizer(vocab_file)
+
+        processor = CustomProcessor(feature_extractor, tokenizer)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            repo = Repository(tmp_dir, clone_from=f"{USER}/test-dynamic-processor", use_auth_token=self._token)
+            processor.save_pretrained(tmp_dir)
+
+            # This has added the proper auto_map field to the feature extractor config
+            self.assertDictEqual(
+                processor.feature_extractor.auto_map,
+                {
+                    "AutoFeatureExtractor": "custom_feature_extraction.CustomFeatureExtractor",
+                    "AutoProcessor": "custom_processing.CustomProcessor",
+                },
+            )
+
+            # This has added the proper auto_map field to the tokenizer config
+            with open(os.path.join(tmp_dir, "tokenizer_config.json")) as f:
+                tokenizer_config = json.load(f)
+            self.assertDictEqual(
+                tokenizer_config["auto_map"],
+                {
+                    "AutoTokenizer": ["custom_tokenization.CustomTokenizer", None],
+                    "AutoProcessor": "custom_processing.CustomProcessor",
+                },
+            )
+
+            # The code has been copied from fixtures
+            self.assertTrue(os.path.isfile(os.path.join(tmp_dir, "custom_feature_extraction.py")))
+            self.assertTrue(os.path.isfile(os.path.join(tmp_dir, "custom_tokenization.py")))
+            self.assertTrue(os.path.isfile(os.path.join(tmp_dir, "custom_processing.py")))
+
+            repo.push_to_hub()
+
+        new_processor = AutoProcessor.from_pretrained(f"{USER}/test-dynamic-processor", trust_remote_code=True)
+        # Can't make an isinstance check because the new_processor is from the CustomProcessor class of a dynamic module
+        self.assertEqual(new_processor.__class__.__name__, "CustomProcessor")
diff --git a/tests/test_tokenization_auto.py b/tests/test_tokenization_auto.py
index a6608c90ca5..ae4e5896508 100644
--- a/tests/test_tokenization_auto.py
+++ b/tests/test_tokenization_auto.py
@@ -310,6 +310,38 @@ class AutoTokenizerTest(unittest.TestCase):
         if CustomConfig in TOKENIZER_MAPPING._extra_content:
             del TOKENIZER_MAPPING._extra_content[CustomConfig]
 
+    def test_from_pretrained_dynamic_tokenizer(self):
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/test_dynamic_tokenizer", trust_remote_code=True)
+        self.assertTrue(tokenizer.special_attribute_present)
+        if is_tokenizers_available():
+            self.assertEqual(tokenizer.__class__.__name__, "NewTokenizerFast")
+
+            # Test we can also load the slow version
+            tokenizer = AutoTokenizer.from_pretrained(
+                "hf-internal-testing/test_dynamic_tokenizer", trust_remote_code=True, use_fast=False
+            )
+            self.assertTrue(tokenizer.special_attribute_present)
+            self.assertEqual(tokenizer.__class__.__name__, "NewTokenizer")
+        else:
+            self.assertEqual(tokenizer.__class__.__name__, "NewTokenizer")
+
+    def test_from_pretrained_dynamic_tokenizer_legacy_format(self):
+        tokenizer = AutoTokenizer.from_pretrained(
+            "hf-internal-testing/test_dynamic_tokenizer_legacy", trust_remote_code=True
+        )
+        self.assertTrue(tokenizer.special_attribute_present)
+        if is_tokenizers_available():
+            self.assertEqual(tokenizer.__class__.__name__, "NewTokenizerFast")
+
+            # Test we can also load the slow version
+            tokenizer = AutoTokenizer.from_pretrained(
+                "hf-internal-testing/test_dynamic_tokenizer_legacy", trust_remote_code=True, use_fast=False
+            )
+            self.assertTrue(tokenizer.special_attribute_present)
+            self.assertEqual(tokenizer.__class__.__name__, "NewTokenizer")
+        else:
+            self.assertEqual(tokenizer.__class__.__name__, "NewTokenizer")
+
     def test_repo_not_found(self):
         with self.assertRaisesRegex(
             EnvironmentError, "bert-base is not a local folder and is not a valid model identifier"
diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py
index 44c55b423c6..e58ab9a816a 100644
--- a/tests/test_tokenization_common.py
+++ b/tests/test_tokenization_common.py
@@ -3812,7 +3812,9 @@ class TokenizerPushToHubTester(unittest.TestCase):
 
             with open(os.path.join(tmp_dir, "tokenizer_config.json")) as f:
                 tokenizer_config = json.load(f)
-            self.assertEqual(tokenizer_config["auto_map"], ["custom_tokenization.CustomTokenizer", None])
+            self.assertDictEqual(
+                tokenizer_config["auto_map"], {"AutoTokenizer": ["custom_tokenization.CustomTokenizer", None]}
+            )
 
             repo.push_to_hub()
 
@@ -3837,9 +3839,14 @@ class TokenizerPushToHubTester(unittest.TestCase):
 
             with open(os.path.join(tmp_dir, "tokenizer_config.json")) as f:
                 tokenizer_config = json.load(f)
-            self.assertEqual(
+            self.assertDictEqual(
                 tokenizer_config["auto_map"],
-                ["custom_tokenization.CustomTokenizer", "custom_tokenization_fast.CustomTokenizerFast"],
+                {
+                    "AutoTokenizer": [
+                        "custom_tokenization.CustomTokenizer",
+                        "custom_tokenization_fast.CustomTokenizerFast",
+                    ]
+                },
             )
 
             repo.push_to_hub()
diff --git a/utils/test_module/custom_processing.py b/utils/test_module/custom_processing.py
new file mode 100644
index 00000000000..196fc511b65
--- /dev/null
+++ b/utils/test_module/custom_processing.py
@@ -0,0 +1,6 @@
+from transformers import ProcessorMixin
+
+
+class CustomProcessor(ProcessorMixin):
+    feature_extractor_class = "AutoFeatureExtractor"
+    tokenizer_class = "AutoTokenizer"