Allow custom code for Processors (#15649)

* Allow custom code for Processors

* Add more test

* Test all auto_map configs are properly set
Sylvain Gugger 2022-02-15 09:44:35 -05:00 committed by GitHub
parent 86a7845c0c
commit 45f56580a7
9 changed files with 288 additions and 41 deletions

View File

@@ -395,8 +395,8 @@ def custom_object_save(obj, folder, config=None):
             "this code in a separate module so we can include it in the saved folder and make it easier to share via "
             "the Hub."
         )
-    # Add object class to the config auto_map
-    if config is not None:
+
+    def _set_auto_map_in_config(_config):
         module_name = obj.__class__.__module__
         last_module = module_name.split(".")[-1]
         full_name = f"{last_module}.{obj.__class__.__name__}"
@@ -418,12 +418,21 @@ def custom_object_save(obj, folder, config=None):
             full_name = (slow_tokenizer_class, fast_tokenizer_class)

-        if isinstance(config, dict):
-            config["auto_map"] = full_name
-        elif getattr(config, "auto_map", None) is not None:
-            config.auto_map[obj._auto_class] = full_name
-        else:
-            config.auto_map = {obj._auto_class: full_name}
+        if isinstance(_config, dict):
+            auto_map = _config.get("auto_map", {})
+            auto_map[obj._auto_class] = full_name
+            _config["auto_map"] = auto_map
+        elif getattr(_config, "auto_map", None) is not None:
+            _config.auto_map[obj._auto_class] = full_name
+        else:
+            _config.auto_map = {obj._auto_class: full_name}
+
+    # Add object class to the config auto_map
+    if isinstance(config, (list, tuple)):
+        for cfg in config:
+            _set_auto_map_in_config(cfg)
+    elif config is not None:
+        _set_auto_map_in_config(config)

     # Copy module file to the output folder.
     object_file = sys.modules[obj.__module__].__file__
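For reference, a minimal sketch (not part of the diff) of the auto_map bookkeeping above, applied to a plain dict config and to an object-style config; set_auto_map, DummyConfig-style SimpleNamespace and the example class paths are illustrative only:

from types import SimpleNamespace

def set_auto_map(config, auto_class, full_name):
    # Mirrors _set_auto_map_in_config: merge into an existing auto_map instead of overwriting it.
    if isinstance(config, dict):
        auto_map = config.get("auto_map", {})
        auto_map[auto_class] = full_name
        config["auto_map"] = auto_map
    elif getattr(config, "auto_map", None) is not None:
        config.auto_map[auto_class] = full_name
    else:
        config.auto_map = {auto_class: full_name}

tokenizer_config = {"tokenizer_class": "CustomTokenizer"}
feature_extractor_config = SimpleNamespace(auto_map=None)

set_auto_map(tokenizer_config, "AutoProcessor", "custom_processing.CustomProcessor")
set_auto_map(feature_extractor_config, "AutoFeatureExtractor", "custom_feature_extraction.CustomFeatureExtractor")

print(tokenizer_config["auto_map"])       # {'AutoProcessor': 'custom_processing.CustomProcessor'}
print(feature_extractor_config.auto_map)  # {'AutoFeatureExtractor': 'custom_feature_extraction.CustomFeatureExtractor'}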

View File

@@ -20,19 +20,22 @@ from collections import OrderedDict
 # Build the list of all feature extractors
 from ...configuration_utils import PretrainedConfig
+from ...dynamic_module_utils import get_class_from_dynamic_module
 from ...feature_extraction_utils import FeatureExtractionMixin
 from ...file_utils import CONFIG_NAME, FEATURE_EXTRACTOR_NAME, get_file_from_repo
 from ...tokenization_utils import TOKENIZER_CONFIG_FILE
+from ...utils import logging
 from .auto_factory import _LazyAutoMapping
 from .configuration_auto import (
     CONFIG_MAPPING_NAMES,
     AutoConfig,
-    config_class_to_model_type,
     model_type_to_module_name,
     replace_list_option_in_docstrings,
 )

+logger = logging.get_logger(__name__)
+

 PROCESSOR_MAPPING_NAMES = OrderedDict(
     [
         ("clip", "CLIPProcessor"),
@@ -120,6 +123,10 @@ class AutoProcessor:
                 functions returns a `Tuple(feature_extractor, unused_kwargs)` where *unused_kwargs* is a dictionary
                 consisting of the key/value pairs whose keys are not feature extractor attributes: i.e., the part of
                 `kwargs` which has not been used to update `feature_extractor` and is otherwise ignored.
+            trust_remote_code (`bool`, *optional*, defaults to `False`):
+                Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
+                should only be set to `True` for repositories you trust and in which you have read the code, as it will
+                execute code present on the Hub on your local machine.
             kwargs (`Dict[str, Any]`, *optional*):
                 The values in kwargs of any keys which are feature extractor attributes will be used to override the
                 loaded values. Behavior concerning key/value pairs whose keys are *not* feature extractor attributes is
@@ -143,10 +150,14 @@ class AutoProcessor:
         >>> processor = AutoProcessor.from_pretrained("./test/saved_model/")
         ```"""
         config = kwargs.pop("config", None)
+        trust_remote_code = kwargs.pop("trust_remote_code", False)
         kwargs["_from_auto"] = True

+        processor_class = None
+        processor_auto_map = None
+
         # First, let's see if we have a preprocessor config.
-        # Filter the kwargs for `get_file_from_repo``.
+        # Filter the kwargs for `get_file_from_repo`.
         get_file_from_repo_kwargs = {
             key: kwargs[key] for key in inspect.signature(get_file_from_repo).parameters.keys() if key in kwargs
         }
@@ -156,12 +167,12 @@ class AutoProcessor:
         )
+        # Let's start by checking whether the processor class is saved in a feature extractor
         if preprocessor_config_file is not None:
             config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs)
-            if "processor_class" in config_dict:
-                processor_class = processor_class_from_name(config_dict["processor_class"])
-                return processor_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
+            processor_class = config_dict.get("processor_class", None)
+            if "AutoProcessor" in config_dict.get("auto_map", {}):
+                processor_auto_map = config_dict["auto_map"]["AutoProcessor"]

-        # Next, let's check whether the processor class is saved in a tokenizer
-        tokenizer_config_file = get_file_from_repo(
-            pretrained_model_name_or_path, TOKENIZER_CONFIG_FILE, **get_file_from_repo_kwargs
-        )
+        if processor_class is None:
+            # Next, let's check whether the processor class is saved in a tokenizer
+            tokenizer_config_file = get_file_from_repo(
+                pretrained_model_name_or_path, TOKENIZER_CONFIG_FILE, **get_file_from_repo_kwargs
+            )
@@ -169,22 +180,50 @@ class AutoProcessor:
             with open(tokenizer_config_file, encoding="utf-8") as reader:
                 config_dict = json.load(reader)

-            if "processor_class" in config_dict:
-                processor_class = processor_class_from_name(config_dict["processor_class"])
-                return processor_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
+            processor_class = config_dict.get("processor_class", None)
+            if "AutoProcessor" in config_dict.get("auto_map", {}):
+                processor_auto_map = config_dict["auto_map"]["AutoProcessor"]

-        # Otherwise, load config, if it can be loaded.
-        if not isinstance(config, PretrainedConfig):
-            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
-
-        model_type = config_class_to_model_type(type(config).__name__)
-
-        if getattr(config, "processor_class", None) is not None:
-            processor_class = processor_class_from_name(config.processor_class)
-            return processor_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
-
-        model_type = config_class_to_model_type(type(config).__name__)
-        if model_type is not None:
+        if processor_class is None:
+            # Otherwise, load config, if it can be loaded.
+            if not isinstance(config, PretrainedConfig):
+                config = AutoConfig.from_pretrained(
+                    pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
+                )
+
+            # And check if the config contains the processor class.
+            processor_class = getattr(config, "processor_class", None)
+            if hasattr(config, "auto_map") and "AutoProcessor" in config.auto_map:
+                processor_auto_map = config.auto_map["AutoProcessor"]
+
+        if processor_class is not None:
+            # If we have custom code for a feature extractor, we get the proper class.
+            if processor_auto_map is not None:
+                if not trust_remote_code:
+                    raise ValueError(
+                        f"Loading {pretrained_model_name_or_path} requires you to execute the feature extractor file "
+                        "in that repo on your local machine. Make sure you have read the code there to avoid "
+                        "malicious use, then set the option `trust_remote_code=True` to remove this error."
+                    )
+                if kwargs.get("revision", None) is None:
+                    logger.warning(
+                        "Explicitly passing a `revision` is encouraged when loading a feature extractor with custom "
+                        "code to ensure no malicious code has been contributed in a newer revision."
+                    )
+
+                module_file, class_name = processor_auto_map.split(".")
+                processor_class = get_class_from_dynamic_module(
+                    pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs
+                )
+            else:
+                processor_class = processor_class_from_name(processor_class)
+
+            return processor_class.from_pretrained(
+                pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
+            )
+
+        # Last try: we use the PROCESSOR_MAPPING.
+        if type(config) in PROCESSOR_MAPPING:
             return PROCESSOR_MAPPING[type(config)].from_pretrained(pretrained_model_name_or_path, **kwargs)

         raise ValueError(
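The net effect for users, sketched below; it mirrors the new docstring and the tests added later in this commit, and the repo name comes from those tests (pinning a revision is only the recommendation echoed by the new warning):

from transformers import AutoProcessor

# Without trust_remote_code=True this now raises a ValueError instead of silently executing
# repository code; with it, the processor class is fetched from the repo's own module via
# get_class_from_dynamic_module.
processor = AutoProcessor.from_pretrained(
    "hf-internal-testing/test_dynamic_processor",
    trust_remote_code=True,
)
print(processor.__class__.__name__)  # e.g. "NewProcessor" for that test repo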

View File

@@ -469,7 +469,13 @@ class AutoTokenizer:
         # Next, let's try to use the tokenizer_config file to get the tokenizer class.
         tokenizer_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs)
         config_tokenizer_class = tokenizer_config.get("tokenizer_class")
-        tokenizer_auto_map = tokenizer_config.get("auto_map")
+        tokenizer_auto_map = None
+        if "auto_map" in tokenizer_config:
+            if isinstance(tokenizer_config["auto_map"], (tuple, list)):
+                # Legacy format for dynamic tokenizers
+                tokenizer_auto_map = tokenizer_config["auto_map"]
+            else:
+                tokenizer_auto_map = tokenizer_config["auto_map"].get("AutoTokenizer", None)

         # If that did not work, let's try to use the config.
         if config_tokenizer_class is None:
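Both auto_map layouts the branch above accepts, shown side by side; resolve_auto_map is an illustrative helper for this sketch, not a transformers function:

def resolve_auto_map(tokenizer_config):
    # New layout: a dict keyed by auto class; legacy layout: a bare (slow, fast) pair.
    if "auto_map" not in tokenizer_config:
        return None
    auto_map = tokenizer_config["auto_map"]
    if isinstance(auto_map, (tuple, list)):  # legacy format for dynamic tokenizers
        return auto_map
    return auto_map.get("AutoTokenizer", None)

legacy = {"auto_map": ["custom_tokenization.CustomTokenizer", None]}
current = {"auto_map": {"AutoTokenizer": ["custom_tokenization.CustomTokenizer", None]}}
assert resolve_auto_map(legacy) == resolve_auto_map(current)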

View File

@@ -17,10 +17,14 @@
 """
 import importlib.util
+import os
 from pathlib import Path

+from .dynamic_module_utils import custom_object_save
+from .tokenization_utils_base import PreTrainedTokenizerBase
+

-# Comment to write
+# Dynamically import the Transformers module to grab the attribute classes of the processor form their names.
 spec = importlib.util.spec_from_file_location(
     "transformers", Path(__file__).parent / "__init__.py", submodule_search_locations=[Path(__file__).parent]
 )
@@ -42,6 +46,7 @@ class ProcessorMixin:
     # Names need to be attr_class for attr in attributes
     feature_extractor_class = None
    tokenizer_class = None
+    _auto_class = None

     # args have to match the attributes class attribute
     def __init__(self, *args, **kwargs):
@@ -101,6 +106,14 @@ class ProcessorMixin:
             Directory where the feature extractor JSON file and the tokenizer files will be saved (directory will
             be created if it does not exist).
         """
+        os.makedirs(save_directory, exist_ok=True)
+        # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
+        # loaded from the Hub.
+        if self._auto_class is not None:
+            attrs = [getattr(self, attribute_name) for attribute_name in self.attributes]
+            configs = [(a.init_kwargs if isinstance(a, PreTrainedTokenizerBase) else a) for a in attrs]
+            custom_object_save(self, save_directory, config=configs)
+
         for attribute_name in self.attributes:
             attribute = getattr(self, attribute_name)
             # Include the processor class in the attribute config so this processor can then be reloaded with the
@@ -109,6 +122,13 @@ class ProcessorMixin:
             attribute._set_processor_class(self.__class__.__name__)
             attribute.save_pretrained(save_directory)

+        if self._auto_class is not None:
+            # We added an attribute to the init_kwargs of the tokenizers, which needs to be cleaned up.
+            for attribute_name in self.attributes:
+                attribute = getattr(self, attribute_name)
+                if isinstance(attribute, PreTrainedTokenizerBase):
+                    del attribute.init_kwargs["auto_map"]
+
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
         r"""
@@ -142,6 +162,32 @@ class ProcessorMixin:
         args = cls._get_arguments_from_pretrained(pretrained_model_name_or_path, **kwargs)
         return cls(*args)

+    @classmethod
+    def register_for_auto_class(cls, auto_class="AutoProcessor"):
+        """
+        Register this class with a given auto class. This should only be used for custom feature extractors as the ones
+        in the library are already mapped with `AutoProcessor`.
+
+        <Tip warning={true}>
+
+        This API is experimental and may have some slight breaking changes in the next releases.
+
+        </Tip>
+
+        Args:
+            auto_class (`str` or `type`, *optional*, defaults to `"AutoProcessor"`):
+                The auto class to register this new feature extractor with.
+        """
+        if not isinstance(auto_class, str):
+            auto_class = auto_class.__name__
+
+        import transformers.models.auto as auto_module
+
+        if not hasattr(auto_module, auto_class):
+            raise ValueError(f"{auto_class} is not a valid auto class.")
+
+        cls._auto_class = auto_class
+
     @classmethod
     def _get_arguments_from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
         args = []
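Putting the new ProcessorMixin hooks together, a rough usage sketch (assumptions: the class lives in its own module so save_pretrained can copy the file, and feature_extractor/tokenizer instances already exist; names other than register_for_auto_class and save_pretrained are illustrative):

# custom_processing.py -- must not be defined in __main__, otherwise custom_object_save
# only logs a warning and cannot copy the defining module next to the saved configs.
from transformers import ProcessorMixin

class CustomProcessor(ProcessorMixin):
    feature_extractor_class = "AutoFeatureExtractor"
    tokenizer_class = "AutoTokenizer"

# Elsewhere (sketch):
# from custom_processing import CustomProcessor
# CustomProcessor.register_for_auto_class()                  # maps the class to "AutoProcessor"
# processor = CustomProcessor(feature_extractor, tokenizer)
# processor.save_pretrained("saved-processor")               # copies custom_processing.py, writes auto_map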

View File

@@ -40,8 +40,6 @@ if is_torch_available():
 if is_vision_available():
     from PIL import Image

-SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")
-
 SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")
@@ -119,7 +117,7 @@ class FeatureExtractionSavingTestMixin:

 @is_staging_test
-class ConfigPushToHubTester(unittest.TestCase):
+class FeatureExtractorPushToHubTester(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
         cls._token = login(username=USER, password=PASS)

View File

@@ -15,20 +15,34 @@
 import json
 import os
+import sys
 import tempfile
 import unittest
+from pathlib import Path
 from shutil import copyfile

+from huggingface_hub import Repository, delete_repo, login
+from requests.exceptions import HTTPError
 from transformers import AutoProcessor, AutoTokenizer, Wav2Vec2Config, Wav2Vec2FeatureExtractor, Wav2Vec2Processor
-from transformers.file_utils import FEATURE_EXTRACTOR_NAME
+from transformers.file_utils import FEATURE_EXTRACTOR_NAME, is_tokenizers_available
+from transformers.testing_utils import PASS, USER, is_staging_test
 from transformers.tokenization_utils import TOKENIZER_CONFIG_FILE

+sys.path.append(str(Path(__file__).parent.parent / "utils"))
+
+from test_module.custom_feature_extraction import CustomFeatureExtractor  # noqa E402
+from test_module.custom_processing import CustomProcessor  # noqa E402
+from test_module.custom_tokenization import CustomTokenizer  # noqa E402
+

 SAMPLE_PROCESSOR_CONFIG = os.path.join(
     os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy_feature_extractor_config.json"
 )
 SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/vocab.json")

+SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")
+

 class AutoFeatureExtractorTest(unittest.TestCase):
     def test_processor_from_model_shortcut(self):
@@ -115,3 +129,93 @@ class AutoFeatureExtractorTest(unittest.TestCase):
             processor = AutoProcessor.from_pretrained(tmpdirname)
             self.assertIsInstance(processor, Wav2Vec2Processor)

+    def test_from_pretrained_dynamic_processor(self):
+        processor = AutoProcessor.from_pretrained("hf-internal-testing/test_dynamic_processor", trust_remote_code=True)
+        self.assertTrue(processor.special_attribute_present)
+        self.assertEqual(processor.__class__.__name__, "NewProcessor")
+
+        feature_extractor = processor.feature_extractor
+        self.assertTrue(feature_extractor.special_attribute_present)
+        self.assertEqual(feature_extractor.__class__.__name__, "NewFeatureExtractor")
+
+        tokenizer = processor.tokenizer
+        self.assertTrue(tokenizer.special_attribute_present)
+        if is_tokenizers_available():
+            self.assertEqual(tokenizer.__class__.__name__, "NewTokenizerFast")
+
+            # Test we can also load the slow version
+            processor = AutoProcessor.from_pretrained(
+                "hf-internal-testing/test_dynamic_processor", trust_remote_code=True, use_fast=False
+            )
+            tokenizer = processor.tokenizer
+            self.assertTrue(tokenizer.special_attribute_present)
+            self.assertEqual(tokenizer.__class__.__name__, "NewTokenizer")
+        else:
+            self.assertEqual(tokenizer.__class__.__name__, "NewTokenizer")
+
+
+@is_staging_test
+class ProcessorPushToHubTester(unittest.TestCase):
+    vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "bla", "blou"]
+
+    @classmethod
+    def setUpClass(cls):
+        cls._token = login(username=USER, password=PASS)
+
+    @classmethod
+    def tearDownClass(cls):
+        try:
+            delete_repo(token=cls._token, name="test-dynamic-processor")
+        except HTTPError:
+            pass
+
+    def test_push_to_hub_dynamic_processor(self):
+        CustomFeatureExtractor.register_for_auto_class()
+        CustomTokenizer.register_for_auto_class()
+        CustomProcessor.register_for_auto_class()
+
+        feature_extractor = CustomFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            vocab_file = os.path.join(tmp_dir, "vocab.txt")
+            with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
+                vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
+            tokenizer = CustomTokenizer(vocab_file)
+
+        processor = CustomProcessor(feature_extractor, tokenizer)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            repo = Repository(tmp_dir, clone_from=f"{USER}/test-dynamic-processor", use_auth_token=self._token)
+            processor.save_pretrained(tmp_dir)
+
+            # This has added the proper auto_map field to the feature extractor config
+            self.assertDictEqual(
+                processor.feature_extractor.auto_map,
+                {
+                    "AutoFeatureExtractor": "custom_feature_extraction.CustomFeatureExtractor",
+                    "AutoProcessor": "custom_processing.CustomProcessor",
+                },
+            )
+
+            # This has added the proper auto_map field to the tokenizer config
+            with open(os.path.join(tmp_dir, "tokenizer_config.json")) as f:
+                tokenizer_config = json.load(f)
+            self.assertDictEqual(
+                tokenizer_config["auto_map"],
+                {
+                    "AutoTokenizer": ["custom_tokenization.CustomTokenizer", None],
+                    "AutoProcessor": "custom_processing.CustomProcessor",
+                },
+            )
+
+            # The code has been copied from fixtures
+            self.assertTrue(os.path.isfile(os.path.join(tmp_dir, "custom_feature_extraction.py")))
+            self.assertTrue(os.path.isfile(os.path.join(tmp_dir, "custom_tokenization.py")))
+            self.assertTrue(os.path.isfile(os.path.join(tmp_dir, "custom_processing.py")))
+
+            repo.push_to_hub()
+
+        new_processor = AutoProcessor.from_pretrained(f"{USER}/test-dynamic-processor", trust_remote_code=True)
+        # Can't make an isinstance check because the new_processor is from the CustomProcessor class of a dynamic module
+        self.assertEqual(new_processor.__class__.__name__, "CustomProcessor")

View File

@@ -310,6 +310,38 @@ class AutoTokenizerTest(unittest.TestCase):
             if CustomConfig in TOKENIZER_MAPPING._extra_content:
                 del TOKENIZER_MAPPING._extra_content[CustomConfig]

+    def test_from_pretrained_dynamic_tokenizer(self):
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/test_dynamic_tokenizer", trust_remote_code=True)
+        self.assertTrue(tokenizer.special_attribute_present)
+        if is_tokenizers_available():
+            self.assertEqual(tokenizer.__class__.__name__, "NewTokenizerFast")
+
+            # Test we can also load the slow version
+            tokenizer = AutoTokenizer.from_pretrained(
+                "hf-internal-testing/test_dynamic_tokenizer", trust_remote_code=True, use_fast=False
+            )
+            self.assertTrue(tokenizer.special_attribute_present)
+            self.assertEqual(tokenizer.__class__.__name__, "NewTokenizer")
+        else:
+            self.assertEqual(tokenizer.__class__.__name__, "NewTokenizer")
+
+    def test_from_pretrained_dynamic_tokenizer_legacy_format(self):
+        tokenizer = AutoTokenizer.from_pretrained(
+            "hf-internal-testing/test_dynamic_tokenizer_legacy", trust_remote_code=True
+        )
+        self.assertTrue(tokenizer.special_attribute_present)
+        if is_tokenizers_available():
+            self.assertEqual(tokenizer.__class__.__name__, "NewTokenizerFast")
+
+            # Test we can also load the slow version
+            tokenizer = AutoTokenizer.from_pretrained(
+                "hf-internal-testing/test_dynamic_tokenizer_legacy", trust_remote_code=True, use_fast=False
+            )
+            self.assertTrue(tokenizer.special_attribute_present)
+            self.assertEqual(tokenizer.__class__.__name__, "NewTokenizer")
+        else:
+            self.assertEqual(tokenizer.__class__.__name__, "NewTokenizer")
+
     def test_repo_not_found(self):
         with self.assertRaisesRegex(
             EnvironmentError, "bert-base is not a local folder and is not a valid model identifier"
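A matching usage sketch for the tokenizer side; the repo names are taken from the tests above:

from transformers import AutoTokenizer

# Both the dict-style auto_map repo and the legacy list-style one load the same way;
# use_fast=False selects the slow class exercised by the tests.
fast_tok = AutoTokenizer.from_pretrained(
    "hf-internal-testing/test_dynamic_tokenizer", trust_remote_code=True
)
slow_tok = AutoTokenizer.from_pretrained(
    "hf-internal-testing/test_dynamic_tokenizer_legacy", trust_remote_code=True, use_fast=False
)
print(fast_tok.__class__.__name__, slow_tok.__class__.__name__)  # e.g. NewTokenizerFast NewTokenizer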

View File

@@ -3812,7 +3812,9 @@ class TokenizerPushToHubTester(unittest.TestCase):
             with open(os.path.join(tmp_dir, "tokenizer_config.json")) as f:
                 tokenizer_config = json.load(f)
-            self.assertEqual(tokenizer_config["auto_map"], ["custom_tokenization.CustomTokenizer", None])
+            self.assertDictEqual(
+                tokenizer_config["auto_map"], {"AutoTokenizer": ["custom_tokenization.CustomTokenizer", None]}
+            )

             repo.push_to_hub()
@@ -3837,9 +3839,14 @@ class TokenizerPushToHubTester(unittest.TestCase):
             with open(os.path.join(tmp_dir, "tokenizer_config.json")) as f:
                 tokenizer_config = json.load(f)
-            self.assertEqual(
+            self.assertDictEqual(
                 tokenizer_config["auto_map"],
-                ["custom_tokenization.CustomTokenizer", "custom_tokenization_fast.CustomTokenizerFast"],
+                {
+                    "AutoTokenizer": [
+                        "custom_tokenization.CustomTokenizer",
+                        "custom_tokenization_fast.CustomTokenizerFast",
+                    ]
+                },
             )

             repo.push_to_hub()

View File

@@ -0,0 +1,6 @@
+from transformers import ProcessorMixin
+
+
+class CustomProcessor(ProcessorMixin):
+    feature_extractor_class = "AutoFeatureExtractor"
+    tokenizer_class = "AutoTokenizer"