Expand dynamic supported objects to configs and tokenizers (#14296)

* Dynamic configs

* Add config test

* Better tests

* Add tokenizer and test

* Add to from_config

* With save
This commit is contained in:
Sylvain Gugger 2021-11-08 15:28:25 -05:00 committed by GitHub
parent de635af3f1
commit dfb00bf644
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 272 additions and 10 deletions

View File

@ -378,7 +378,24 @@ class _BaseAutoModelClass:
@classmethod @classmethod
def from_config(cls, config, **kwargs): def from_config(cls, config, **kwargs):
if type(config) in cls._model_mapping.keys(): trust_remote_code = kwargs.pop("trust_remote_code", False)
if hasattr(config, "auto_map") and cls.__name__ in config.auto_map:
if not trust_remote_code:
raise ValueError(
"Loading this model requires you to execute the modeling file in that repo "
"on your local machine. Make sure you have read the code there to avoid malicious use, then set "
"the option `trust_remote_code=True` to remove this error."
)
if kwargs.get("revision", None) is None:
logger.warn(
"Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure "
"no malicious code has been contributed in a newer revision."
)
class_ref = config.auto_map[cls.__name__]
module_file, class_name = class_ref.split(".")
model_class = get_class_from_dynamic_module(config.name_or_path, module_file + ".py", class_name, **kwargs)
return model_class._from_config(config, **kwargs)
elif type(config) in cls._model_mapping.keys():
model_class = _get_model_class(config, cls._model_mapping) model_class = _get_model_class(config, cls._model_mapping)
return model_class._from_config(config, **kwargs) return model_class._from_config(config, **kwargs)
@ -394,7 +411,7 @@ class _BaseAutoModelClass:
kwargs["_from_auto"] = True kwargs["_from_auto"] = True
if not isinstance(config, PretrainedConfig): if not isinstance(config, PretrainedConfig):
config, kwargs = AutoConfig.from_pretrained( config, kwargs = AutoConfig.from_pretrained(
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs pretrained_model_name_or_path, return_unused_kwargs=True, trust_remote_code=trust_remote_code, **kwargs
) )
if hasattr(config, "auto_map") and cls.__name__ in config.auto_map: if hasattr(config, "auto_map") and cls.__name__ in config.auto_map:
if not trust_remote_code: if not trust_remote_code:

View File

@ -21,8 +21,12 @@ from typing import List, Union
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...file_utils import CONFIG_NAME from ...file_utils import CONFIG_NAME
from ...utils import logging
from .dynamic import get_class_from_dynamic_module
logger = logging.get_logger(__name__)
CONFIG_MAPPING_NAMES = OrderedDict( CONFIG_MAPPING_NAMES = OrderedDict(
[ [
# Add configs here # Add configs here
@ -523,6 +527,10 @@ class AutoConfig:
If :obj:`True`, then this functions returns a :obj:`Tuple(config, unused_kwargs)` where `unused_kwargs` If :obj:`True`, then this functions returns a :obj:`Tuple(config, unused_kwargs)` where `unused_kwargs`
is a dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e., is a dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e.,
the part of ``kwargs`` which has not been used to update ``config`` and is otherwise ignored. the part of ``kwargs`` which has not been used to update ``config`` and is otherwise ignored.
trust_remote_code (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
should only be set to :obj:`True` for repositories you trust and in which you have read the code, as it
will execute code present on the Hub on your local machine.
kwargs(additional keyword arguments, `optional`): kwargs(additional keyword arguments, `optional`):
The values in kwargs of any keys which are configuration attributes will be used to override the loaded The values in kwargs of any keys which are configuration attributes will be used to override the loaded
values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
@ -555,8 +563,28 @@ class AutoConfig:
{'foo': False} {'foo': False}
""" """
kwargs["_from_auto"] = True kwargs["_from_auto"] = True
kwargs["name_or_path"] = pretrained_model_name_or_path
trust_remote_code = kwargs.pop("trust_remote_code", False)
config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs) config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs)
if "model_type" in config_dict: if "auto_map" in config_dict and "AutoConfig" in config_dict["auto_map"]:
if not trust_remote_code:
raise ValueError(
f"Loading {pretrained_model_name_or_path} requires you to execute the configuration file in that repo "
"on your local machine. Make sure you have read the code there to avoid malicious use, then set "
"the option `trust_remote_code=True` to remove this error."
)
if kwargs.get("revision", None) is None:
logger.warn(
"Explicitly passing a `revision` is encouraged when loading a configuration with custom code to "
"ensure no malicious code has been contributed in a newer revision."
)
class_ref = config_dict["auto_map"]["AutoConfig"]
module_file, class_name = class_ref.split(".")
config_class = get_class_from_dynamic_module(
pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs
)
return config_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
elif "model_type" in config_dict:
config_class = CONFIG_MAPPING[config_dict["model_type"]] config_class = CONFIG_MAPPING[config_dict["model_type"]]
return config_class.from_dict(config_dict, **kwargs) return config_class.from_dict(config_dict, **kwargs)
else: else:

View File

@ -41,6 +41,7 @@ from .configuration_auto import (
model_type_to_module_name, model_type_to_module_name,
replace_list_option_in_docstrings, replace_list_option_in_docstrings,
) )
from .dynamic import get_class_from_dynamic_module
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@ -412,6 +413,10 @@ class AutoTokenizer:
Whether or not to try to load the fast version of the tokenizer. Whether or not to try to load the fast version of the tokenizer.
tokenizer_type (:obj:`str`, `optional`): tokenizer_type (:obj:`str`, `optional`):
Tokenizer type to be loaded. Tokenizer type to be loaded.
trust_remote_code (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
should only be set to :obj:`True` for repositories you trust and in which you have read the code, as it
will execute code present on the Hub on your local machine.
kwargs (additional keyword arguments, `optional`): kwargs (additional keyword arguments, `optional`):
Will be passed to the Tokenizer ``__init__()`` method. Can be used to set special tokens like Will be passed to the Tokenizer ``__init__()`` method. Can be used to set special tokens like
``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``,
@ -436,6 +441,7 @@ class AutoTokenizer:
use_fast = kwargs.pop("use_fast", True) use_fast = kwargs.pop("use_fast", True)
tokenizer_type = kwargs.pop("tokenizer_type", None) tokenizer_type = kwargs.pop("tokenizer_type", None)
trust_remote_code = kwargs.pop("trust_remote_code", False)
# First, let's see whether the tokenizer_type is passed so that we can leverage it # First, let's see whether the tokenizer_type is passed so that we can leverage it
if tokenizer_type is not None: if tokenizer_type is not None:
@ -464,17 +470,45 @@ class AutoTokenizer:
# Next, let's try to use the tokenizer_config file to get the tokenizer class. # Next, let's try to use the tokenizer_config file to get the tokenizer class.
tokenizer_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs) tokenizer_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs)
config_tokenizer_class = tokenizer_config.get("tokenizer_class") config_tokenizer_class = tokenizer_config.get("tokenizer_class")
tokenizer_auto_map = tokenizer_config.get("auto_map")
# If that did not work, let's try to use the config. # If that did not work, let's try to use the config.
if config_tokenizer_class is None: if config_tokenizer_class is None:
if not isinstance(config, PretrainedConfig): if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) config = AutoConfig.from_pretrained(
pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
)
config_tokenizer_class = config.tokenizer_class config_tokenizer_class = config.tokenizer_class
if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map:
tokenizer_auto_map = config.auto_map["AutoTokenizer"]
# If we have the tokenizer class from the tokenizer config or the model config we're good! # If we have the tokenizer class from the tokenizer config or the model config we're good!
if config_tokenizer_class is not None: if config_tokenizer_class is not None:
tokenizer_class = None tokenizer_class = None
if use_fast and not config_tokenizer_class.endswith("Fast"): if tokenizer_auto_map is not None:
if not trust_remote_code:
raise ValueError(
f"Loading {pretrained_model_name_or_path} requires you to execute the tokenizer file in that repo "
"on your local machine. Make sure you have read the code there to avoid malicious use, then set "
"the option `trust_remote_code=True` to remove this error."
)
if kwargs.get("revision", None) is None:
logger.warn(
"Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure "
"no malicious code has been contributed in a newer revision."
)
if use_fast and tokenizer_auto_map[1] is not None:
class_ref = tokenizer_auto_map[1]
else:
class_ref = tokenizer_auto_map[0]
module_file, class_name = class_ref.split(".")
tokenizer_class = get_class_from_dynamic_module(
pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs
)
elif use_fast and not config_tokenizer_class.endswith("Fast"):
tokenizer_class_candidate = f"{config_tokenizer_class}Fast" tokenizer_class_candidate = f"{config_tokenizer_class}Fast"
tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate) tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)
if tokenizer_class is None: if tokenizer_class is None:

View File

@ -1784,6 +1784,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
# First attempt. We get tokenizer_class from tokenizer_config to check mismatch between tokenizers. # First attempt. We get tokenizer_class from tokenizer_config to check mismatch between tokenizers.
config_tokenizer_class = init_kwargs.get("tokenizer_class") config_tokenizer_class = init_kwargs.get("tokenizer_class")
init_kwargs.pop("tokenizer_class", None) init_kwargs.pop("tokenizer_class", None)
init_kwargs.pop("auto_map", None)
saved_init_inputs = init_kwargs.pop("init_inputs", ()) saved_init_inputs = init_kwargs.pop("init_inputs", ())
if not init_inputs: if not init_inputs:
init_inputs = saved_init_inputs init_inputs = saved_init_inputs
@ -2028,6 +2029,8 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
if tokenizer_class.endswith("Fast") and tokenizer_class != "PreTrainedTokenizerFast": if tokenizer_class.endswith("Fast") and tokenizer_class != "PreTrainedTokenizerFast":
tokenizer_class = tokenizer_class[:-4] tokenizer_class = tokenizer_class[:-4]
tokenizer_config["tokenizer_class"] = tokenizer_class tokenizer_config["tokenizer_class"] = tokenizer_class
if getattr(self, "_auto_map", None) is not None:
tokenizer_config["auto_map"] = self._auto_map
with open(tokenizer_config_file, "w", encoding="utf-8") as f: with open(tokenizer_config_file, "w", encoding="utf-8") as f:
f.write(json.dumps(tokenizer_config, ensure_ascii=False)) f.write(json.dumps(tokenizer_config, ensure_ascii=False))

View File

@ -19,9 +19,9 @@ import os
import tempfile import tempfile
import unittest import unittest
from huggingface_hub import delete_repo, login from huggingface_hub import Repository, delete_repo, login
from requests.exceptions import HTTPError from requests.exceptions import HTTPError
from transformers import BertConfig, GPT2Config, is_torch_available from transformers import AutoConfig, BertConfig, GPT2Config, is_torch_available
from transformers.configuration_utils import PretrainedConfig from transformers.configuration_utils import PretrainedConfig
from transformers.testing_utils import PASS, USER, is_staging_test from transformers.testing_utils import PASS, USER, is_staging_test
@ -190,6 +190,23 @@ class ConfigTester(object):
self.check_config_arguments_init() self.check_config_arguments_init()
class FakeConfig(PretrainedConfig):
def __init__(self, attribute=1, **kwargs):
self.attribute = attribute
super().__init__(**kwargs)
# Make sure this is synchronized with the config above.
FAKE_CONFIG_CODE = """
from transformers import PretrainedConfig
class FakeConfig(PretrainedConfig):
def __init__(self, attribute=1, **kwargs):
self.attribute = attribute
super().__init__(**kwargs)
"""
@is_staging_test @is_staging_test
class ConfigPushToHubTester(unittest.TestCase): class ConfigPushToHubTester(unittest.TestCase):
@classmethod @classmethod
@ -208,6 +225,11 @@ class ConfigPushToHubTester(unittest.TestCase):
except HTTPError: except HTTPError:
pass pass
try:
delete_repo(token=cls._token, name="test-dynamic-config")
except HTTPError:
pass
def test_push_to_hub(self): def test_push_to_hub(self):
config = BertConfig( config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37 vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
@ -238,6 +260,23 @@ class ConfigPushToHubTester(unittest.TestCase):
if k != "transformers_version": if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k)) self.assertEqual(v, getattr(new_config, k))
def test_push_to_hub_dynamic_config(self):
config = FakeConfig(attribute=42)
config.auto_map = {"AutoConfig": "configuration.FakeConfig"}
with tempfile.TemporaryDirectory() as tmp_dir:
repo = Repository(tmp_dir, clone_from=f"{USER}/test-dynamic-config", use_auth_token=self._token)
config.save_pretrained(tmp_dir)
with open(os.path.join(tmp_dir, "configuration.py"), "w") as f:
f.write(FAKE_CONFIG_CODE)
repo.push_to_hub()
new_config = AutoConfig.from_pretrained(f"{USER}/test-dynamic-config", trust_remote_code=True)
# Can't make an isinstance check because the new_config is from the FakeConfig class of a dynamic module
self.assertEqual(new_config.__class__.__name__, "FakeConfig")
self.assertEqual(new_config.attribute, 42)
class ConfigTestUtils(unittest.TestCase): class ConfigTestUtils(unittest.TestCase):
def test_config_from_string(self): def test_config_from_string(self):

View File

@ -30,7 +30,14 @@ import numpy as np
import transformers import transformers
from huggingface_hub import Repository, delete_repo, login from huggingface_hub import Repository, delete_repo, login
from requests.exceptions import HTTPError from requests.exceptions import HTTPError
from transformers import AutoModel, AutoModelForSequenceClassification, is_torch_available, logging from transformers import (
AutoConfig,
AutoModel,
AutoModelForSequenceClassification,
PretrainedConfig,
is_torch_available,
logging,
)
from transformers.file_utils import WEIGHTS_NAME, is_flax_available, is_torch_fx_available from transformers.file_utils import WEIGHTS_NAME, is_flax_available, is_torch_fx_available
from transformers.models.auto import get_values from transformers.models.auto import get_values
from transformers.testing_utils import ( from transformers.testing_utils import (
@ -67,7 +74,6 @@ if is_torch_available():
AdaptiveEmbedding, AdaptiveEmbedding,
BertConfig, BertConfig,
BertModel, BertModel,
PretrainedConfig,
PreTrainedModel, PreTrainedModel,
T5Config, T5Config,
T5ForConditionalGeneration, T5ForConditionalGeneration,
@ -2078,6 +2084,23 @@ class ModelUtilsTest(TestCasePlus):
self.assertEqual(model.dtype, torch.float16) self.assertEqual(model.dtype, torch.float16)
class FakeConfig(PretrainedConfig):
def __init__(self, attribute=1, **kwargs):
self.attribute = attribute
super().__init__(**kwargs)
# Make sure this is synchronized with the config above.
FAKE_CONFIG_CODE = """
from transformers import PretrainedConfig
class FakeConfig(PretrainedConfig):
def __init__(self, attribute=1, **kwargs):
self.attribute = attribute
super().__init__(**kwargs)
"""
if is_torch_available(): if is_torch_available():
class FakeModel(PreTrainedModel): class FakeModel(PreTrainedModel):
@ -2140,6 +2163,11 @@ class ModelPushToHubTester(unittest.TestCase):
except HTTPError: except HTTPError:
pass pass
try:
delete_repo(token=cls._token, name="test-dynamic-model-config")
except HTTPError:
pass
def test_push_to_hub(self): def test_push_to_hub(self):
config = BertConfig( config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37 vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
@ -2185,5 +2213,47 @@ class ModelPushToHubTester(unittest.TestCase):
repo.push_to_hub() repo.push_to_hub()
new_model = AutoModel.from_pretrained(f"{USER}/test-dynamic-model", trust_remote_code=True) new_model = AutoModel.from_pretrained(f"{USER}/test-dynamic-model", trust_remote_code=True)
# Can't make an isinstance check because the new_model is from the FakeModel class of a dynamic module
self.assertEqual(new_model.__class__.__name__, "FakeModel")
for p1, p2 in zip(model.parameters(), new_model.parameters()): for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2)) self.assertTrue(torch.equal(p1, p2))
config = AutoConfig.from_pretrained(f"{USER}/test-dynamic-model")
new_model = AutoModel.from_config(config, trust_remote_code=True)
self.assertEqual(new_model.__class__.__name__, "FakeModel")
def test_push_to_hub_dynamic_model_and_config(self):
config = FakeConfig(
attribute=42,
vocab_size=99,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
)
config.auto_map = {"AutoConfig": "configuration.FakeConfig", "AutoModel": "modeling.FakeModel"}
model = FakeModel(config)
with tempfile.TemporaryDirectory() as tmp_dir:
repo = Repository(tmp_dir, clone_from=f"{USER}/test-dynamic-model-config", use_auth_token=self._token)
model.save_pretrained(tmp_dir)
with open(os.path.join(tmp_dir, "configuration.py"), "w") as f:
f.write(FAKE_CONFIG_CODE)
with open(os.path.join(tmp_dir, "modeling.py"), "w") as f:
f.write(FAKE_MODEL_CODE)
repo.push_to_hub()
new_model = AutoModel.from_pretrained(f"{USER}/test-dynamic-model-config", trust_remote_code=True)
# Can't make an isinstance check because the new_model.config is from the FakeConfig class of a dynamic module
self.assertEqual(new_model.config.__class__.__name__, "FakeConfig")
self.assertEqual(new_model.config.attribute, 42)
# Can't make an isinstance check because the new_model is from the FakeModel class of a dynamic module
self.assertEqual(new_model.__class__.__name__, "FakeModel")
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2))
config = AutoConfig.from_pretrained(f"{USER}/test-dynamic-model")
new_model = AutoModel.from_config(config, trust_remote_code=True)
self.assertEqual(new_model.__class__.__name__, "FakeModel")

View File

@ -27,11 +27,12 @@ from collections import OrderedDict
from itertools import takewhile from itertools import takewhile
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union
from huggingface_hub import delete_repo, login from huggingface_hub import Repository, delete_repo, login
from requests.exceptions import HTTPError from requests.exceptions import HTTPError
from transformers import ( from transformers import (
AlbertTokenizer, AlbertTokenizer,
AlbertTokenizerFast, AlbertTokenizerFast,
AutoTokenizer,
BertTokenizer, BertTokenizer,
BertTokenizerFast, BertTokenizerFast,
PreTrainedTokenizer, PreTrainedTokenizer,
@ -41,6 +42,7 @@ from transformers import (
Trainer, Trainer,
TrainingArguments, TrainingArguments,
is_tf_available, is_tf_available,
is_tokenizers_available,
is_torch_available, is_torch_available,
) )
from transformers.testing_utils import ( from transformers.testing_utils import (
@ -3513,6 +3515,28 @@ class TokenizerTesterMixin:
self.assertIn("tokenizer.json", os.listdir(os.path.join(tmp_dir, "checkpoint"))) self.assertIn("tokenizer.json", os.listdir(os.path.join(tmp_dir, "checkpoint")))
class FakeTokenizer(BertTokenizer):
pass
if is_tokenizers_available():
class FakeTokenizerFast(BertTokenizerFast):
pass
# Make sure this is synchronized with the tokenizers above.
FAKE_TOKENIZER_CODE = """
from transformers import BertTokenizer, BertTokenizerFast
class FakeTokenizer(BertTokenizer):
pass
class FakeTokenizerFast(BertTokenizerFast):
pass
"""
@is_staging_test @is_staging_test
class TokenizerPushToHubTester(unittest.TestCase): class TokenizerPushToHubTester(unittest.TestCase):
vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "bla", "blou"] vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "bla", "blou"]
@ -3533,6 +3557,11 @@ class TokenizerPushToHubTester(unittest.TestCase):
except HTTPError: except HTTPError:
pass pass
try:
delete_repo(token=cls._token, name="test-dynamic-tokenizer")
except HTTPError:
pass
def test_push_to_hub(self): def test_push_to_hub(self):
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
vocab_file = os.path.join(tmp_dir, "vocab.txt") vocab_file = os.path.join(tmp_dir, "vocab.txt")
@ -3562,6 +3591,48 @@ class TokenizerPushToHubTester(unittest.TestCase):
new_tokenizer = BertTokenizer.from_pretrained("valid_org/test-tokenizer-org") new_tokenizer = BertTokenizer.from_pretrained("valid_org/test-tokenizer-org")
self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab) self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
def test_push_to_hub_dynamic_tokenizer(self):
with tempfile.TemporaryDirectory() as tmp_dir:
vocab_file = os.path.join(tmp_dir, "vocab.txt")
with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
tokenizer = FakeTokenizer(vocab_file)
# No fast custom tokenizer
tokenizer._auto_map = ("tokenizer.FakeTokenizer", None)
with tempfile.TemporaryDirectory() as tmp_dir:
repo = Repository(tmp_dir, clone_from=f"{USER}/test-dynamic-tokenizer", use_auth_token=self._token)
print(os.listdir((tmp_dir)))
tokenizer.save_pretrained(tmp_dir)
with open(os.path.join(tmp_dir, "tokenizer.py"), "w") as f:
f.write(FAKE_TOKENIZER_CODE)
repo.push_to_hub()
tokenizer = AutoTokenizer.from_pretrained(f"{USER}/test-dynamic-tokenizer", trust_remote_code=True)
# Can't make an isinstance check because the new_model.config is from the FakeConfig class of a dynamic module
self.assertEqual(tokenizer.__class__.__name__, "FakeTokenizer")
# Fast and slow custom tokenizer
tokenizer._auto_map = ("tokenizer.FakeTokenizer", "tokenizer.FakeTokenizerFast")
with tempfile.TemporaryDirectory() as tmp_dir:
repo = Repository(tmp_dir, clone_from=f"{USER}/test-dynamic-tokenizer", use_auth_token=self._token)
print(os.listdir((tmp_dir)))
tokenizer.save_pretrained(tmp_dir)
with open(os.path.join(tmp_dir, "tokenizer.py"), "w") as f:
f.write(FAKE_TOKENIZER_CODE)
repo.push_to_hub()
tokenizer = AutoTokenizer.from_pretrained(f"{USER}/test-dynamic-tokenizer", trust_remote_code=True)
# Can't make an isinstance check because the new_model.config is from the FakeConfig class of a dynamic module
self.assertEqual(tokenizer.__class__.__name__, "FakeTokenizerFast")
tokenizer = AutoTokenizer.from_pretrained(
f"{USER}/test-dynamic-tokenizer", use_fast=False, trust_remote_code=True
)
# Can't make an isinstance check because the new_model.config is from the FakeConfig class of a dynamic module
self.assertEqual(tokenizer.__class__.__name__, "FakeTokenizer")
class TrieTest(unittest.TestCase): class TrieTest(unittest.TestCase):
def test_trie(self): def test_trie(self):