Mirror of https://github.com/huggingface/transformers.git
Expand dynamic supported objects to configs and tokenizers (#14296)
* Dynamic configs
* Add config test
* Better tests
* Add tokenizer and test
* Add to from_config
* With save
This commit is contained in:
parent de635af3f1
commit dfb00bf644
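The user-facing effect, exercised by the tests further down, is that the Auto classes can now load configuration and tokenizer classes defined in a Hub repo's own Python files. A minimal usage sketch, assuming a hypothetical repo `user/custom-model` whose config declares an `auto_map` (the repo name is a placeholder, not something referenced by this commit):

from transformers import AutoConfig, AutoModel, AutoTokenizer

# The repo's config.json / tokenizer_config.json must carry an `auto_map` pointing at
# classes stored in .py files inside the same repo ("user/custom-model" is hypothetical).
# Pinning a revision is encouraged by the warnings added in this commit.
config = AutoConfig.from_pretrained("user/custom-model", trust_remote_code=True, revision="main")
model = AutoModel.from_config(config, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("user/custom-model", trust_remote_code=True, revision="main")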
@@ -378,7 +378,24 @@ class _BaseAutoModelClass:
     @classmethod
     def from_config(cls, config, **kwargs):
-        if type(config) in cls._model_mapping.keys():
+        trust_remote_code = kwargs.pop("trust_remote_code", False)
+        if hasattr(config, "auto_map") and cls.__name__ in config.auto_map:
+            if not trust_remote_code:
+                raise ValueError(
+                    "Loading this model requires you to execute the modeling file in that repo "
+                    "on your local machine. Make sure you have read the code there to avoid malicious use, then set "
+                    "the option `trust_remote_code=True` to remove this error."
+                )
+            if kwargs.get("revision", None) is None:
+                logger.warn(
+                    "Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure "
+                    "no malicious code has been contributed in a newer revision."
+                )
+            class_ref = config.auto_map[cls.__name__]
+            module_file, class_name = class_ref.split(".")
+            model_class = get_class_from_dynamic_module(config.name_or_path, module_file + ".py", class_name, **kwargs)
+            return model_class._from_config(config, **kwargs)
+        elif type(config) in cls._model_mapping.keys():
             model_class = _get_model_class(config, cls._model_mapping)
             return model_class._from_config(config, **kwargs)
@@ -394,7 +411,7 @@ class _BaseAutoModelClass:
         kwargs["_from_auto"] = True
         if not isinstance(config, PretrainedConfig):
             config, kwargs = AutoConfig.from_pretrained(
-                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
+                pretrained_model_name_or_path, return_unused_kwargs=True, trust_remote_code=trust_remote_code, **kwargs
             )
         if hasattr(config, "auto_map") and cls.__name__ in config.auto_map:
             if not trust_remote_code:
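For reference, `from_config` keys `config.auto_map` by the Auto class's own name and splits the reference into a module file and a class name before handing it to `get_class_from_dynamic_module`. A small sketch of that resolution, using the illustrative mapping written by the tests in this commit:

# Example auto_map, as stored in config.json by the tests below (names are illustrative).
auto_map = {"AutoConfig": "configuration.FakeConfig", "AutoModel": "modeling.FakeModel"}

class_ref = auto_map["AutoModel"]  # from_config reads this via cls.__name__
module_file, class_name = class_ref.split(".")
# get_class_from_dynamic_module then fetches class "FakeModel" from "modeling.py" in the
# repo identified by config.name_or_path.
print(module_file + ".py", class_name)  # -> modeling.py FakeModel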
@@ -21,8 +21,12 @@ from typing import List, Union

 from ...configuration_utils import PretrainedConfig
 from ...file_utils import CONFIG_NAME
+from ...utils import logging
+from .dynamic import get_class_from_dynamic_module


+logger = logging.get_logger(__name__)
+
 CONFIG_MAPPING_NAMES = OrderedDict(
     [
         # Add configs here
@@ -523,6 +527,10 @@ class AutoConfig:
                 If :obj:`True`, then this functions returns a :obj:`Tuple(config, unused_kwargs)` where `unused_kwargs`
                 is a dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e.,
                 the part of ``kwargs`` which has not been used to update ``config`` and is otherwise ignored.
+            trust_remote_code (:obj:`bool`, `optional`, defaults to :obj:`False`):
+                Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
+                should only be set to :obj:`True` for repositories you trust and in which you have read the code, as it
+                will execute code present on the Hub on your local machine.
             kwargs(additional keyword arguments, `optional`):
                 The values in kwargs of any keys which are configuration attributes will be used to override the loaded
                 values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
@@ -555,8 +563,28 @@ class AutoConfig:
         {'foo': False}
         """
         kwargs["_from_auto"] = True
+        kwargs["name_or_path"] = pretrained_model_name_or_path
+        trust_remote_code = kwargs.pop("trust_remote_code", False)
         config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs)
-        if "model_type" in config_dict:
+        if "auto_map" in config_dict and "AutoConfig" in config_dict["auto_map"]:
+            if not trust_remote_code:
+                raise ValueError(
+                    f"Loading {pretrained_model_name_or_path} requires you to execute the configuration file in that repo "
+                    "on your local machine. Make sure you have read the code there to avoid malicious use, then set "
+                    "the option `trust_remote_code=True` to remove this error."
+                )
+            if kwargs.get("revision", None) is None:
+                logger.warn(
+                    "Explicitly passing a `revision` is encouraged when loading a configuration with custom code to "
+                    "ensure no malicious code has been contributed in a newer revision."
+                )
+            class_ref = config_dict["auto_map"]["AutoConfig"]
+            module_file, class_name = class_ref.split(".")
+            config_class = get_class_from_dynamic_module(
+                pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs
+            )
+            return config_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
+        elif "model_type" in config_dict:
             config_class = CONFIG_MAPPING[config_dict["model_type"]]
             return config_class.from_dict(config_dict, **kwargs)
         else:
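On the authoring side, the flow exercised by the new config test further down is roughly this sketch (the class name and save directory are illustrative, not part of the commit's API surface):

from transformers import PretrainedConfig

class MyConfig(PretrainedConfig):
    def __init__(self, attribute=1, **kwargs):
        self.attribute = attribute
        super().__init__(**kwargs)

config = MyConfig(attribute=42)
# Point AutoConfig at the class defined in configuration.py inside the repo.
config.auto_map = {"AutoConfig": "configuration.MyConfig"}
config.save_pretrained("path/to/local-clone")  # config.json now includes the auto_map entry
# configuration.py (containing MyConfig) must also be committed to the repo before pushing.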
@@ -41,6 +41,7 @@ from .configuration_auto import (
     model_type_to_module_name,
     replace_list_option_in_docstrings,
 )
+from .dynamic import get_class_from_dynamic_module


 logger = logging.get_logger(__name__)
@@ -412,6 +413,10 @@ class AutoTokenizer:
                 Whether or not to try to load the fast version of the tokenizer.
             tokenizer_type (:obj:`str`, `optional`):
                 Tokenizer type to be loaded.
+            trust_remote_code (:obj:`bool`, `optional`, defaults to :obj:`False`):
+                Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
+                should only be set to :obj:`True` for repositories you trust and in which you have read the code, as it
+                will execute code present on the Hub on your local machine.
             kwargs (additional keyword arguments, `optional`):
                 Will be passed to the Tokenizer ``__init__()`` method. Can be used to set special tokens like
                 ``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``,
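As the docstring above notes, the flag now applies to AutoTokenizer as well; a short usage sketch with a placeholder repo name:

from transformers import AutoTokenizer

# Loads the custom fast tokenizer when the repo maps one, otherwise the custom slow one.
tok = AutoTokenizer.from_pretrained("user/custom-model", trust_remote_code=True)
# Force the custom slow tokenizer.
slow_tok = AutoTokenizer.from_pretrained("user/custom-model", use_fast=False, trust_remote_code=True)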
@@ -436,6 +441,7 @@ class AutoTokenizer:
         use_fast = kwargs.pop("use_fast", True)
         tokenizer_type = kwargs.pop("tokenizer_type", None)
+        trust_remote_code = kwargs.pop("trust_remote_code", False)

         # First, let's see whether the tokenizer_type is passed so that we can leverage it
         if tokenizer_type is not None:
@@ -464,17 +470,45 @@ class AutoTokenizer:
         # Next, let's try to use the tokenizer_config file to get the tokenizer class.
         tokenizer_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs)
         config_tokenizer_class = tokenizer_config.get("tokenizer_class")
+        tokenizer_auto_map = tokenizer_config.get("auto_map")

         # If that did not work, let's try to use the config.
         if config_tokenizer_class is None:
             if not isinstance(config, PretrainedConfig):
-                config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+                config = AutoConfig.from_pretrained(
+                    pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
+                )
             config_tokenizer_class = config.tokenizer_class
+            if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map:
+                tokenizer_auto_map = config.auto_map["AutoTokenizer"]

         # If we have the tokenizer class from the tokenizer config or the model config we're good!
         if config_tokenizer_class is not None:
             tokenizer_class = None
-            if use_fast and not config_tokenizer_class.endswith("Fast"):
+            if tokenizer_auto_map is not None:
+                if not trust_remote_code:
+                    raise ValueError(
+                        f"Loading {pretrained_model_name_or_path} requires you to execute the tokenizer file in that repo "
+                        "on your local machine. Make sure you have read the code there to avoid malicious use, then set "
+                        "the option `trust_remote_code=True` to remove this error."
+                    )
+                if kwargs.get("revision", None) is None:
+                    logger.warn(
+                        "Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure "
+                        "no malicious code has been contributed in a newer revision."
+                    )
+
+                if use_fast and tokenizer_auto_map[1] is not None:
+                    class_ref = tokenizer_auto_map[1]
+                else:
+                    class_ref = tokenizer_auto_map[0]
+
+                module_file, class_name = class_ref.split(".")
+                tokenizer_class = get_class_from_dynamic_module(
+                    pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs
+                )
+
+            elif use_fast and not config_tokenizer_class.endswith("Fast"):
                 tokenizer_class_candidate = f"{config_tokenizer_class}Fast"
                 tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)
             if tokenizer_class is None:
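For tokenizers, `auto_map` is a pair of references, slow class first and fast class (or None) second, and `use_fast` chooses between them as in the branch above. A small sketch of the selection, using the illustrative names from the tests:

# (slow reference, fast reference or None), as set by the tests below via tokenizer._auto_map.
tokenizer_auto_map = ("tokenizer.FakeTokenizer", "tokenizer.FakeTokenizerFast")

use_fast = True
if use_fast and tokenizer_auto_map[1] is not None:
    class_ref = tokenizer_auto_map[1]  # fast custom class
else:
    class_ref = tokenizer_auto_map[0]  # slow custom class
module_file, class_name = class_ref.split(".")
print(module_file + ".py", class_name)  # -> tokenizer.py FakeTokenizerFast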
@@ -1784,6 +1784,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
         # First attempt. We get tokenizer_class from tokenizer_config to check mismatch between tokenizers.
         config_tokenizer_class = init_kwargs.get("tokenizer_class")
         init_kwargs.pop("tokenizer_class", None)
+        init_kwargs.pop("auto_map", None)
         saved_init_inputs = init_kwargs.pop("init_inputs", ())
         if not init_inputs:
             init_inputs = saved_init_inputs
@@ -2028,6 +2029,8 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
         if tokenizer_class.endswith("Fast") and tokenizer_class != "PreTrainedTokenizerFast":
             tokenizer_class = tokenizer_class[:-4]
         tokenizer_config["tokenizer_class"] = tokenizer_class
+        if getattr(self, "_auto_map", None) is not None:
+            tokenizer_config["auto_map"] = self._auto_map

         with open(tokenizer_config_file, "w", encoding="utf-8") as f:
             f.write(json.dumps(tokenizer_config, ensure_ascii=False))
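The two lines added above mean that a tokenizer carrying an `_auto_map` attribute records it in its saved tokenizer config. A minimal sketch, assuming any pretrained tokenizer (the checkpoint name is only an example):

import json
import os
import tempfile

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")  # example checkpoint
tokenizer._auto_map = ("tokenizer.FakeTokenizer", "tokenizer.FakeTokenizerFast")

with tempfile.TemporaryDirectory() as tmp_dir:
    tokenizer.save_pretrained(tmp_dir)
    with open(os.path.join(tmp_dir, "tokenizer_config.json"), encoding="utf-8") as f:
        print(json.load(f)["auto_map"])  # -> ["tokenizer.FakeTokenizer", "tokenizer.FakeTokenizerFast"]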
@@ -19,9 +19,9 @@ import os
 import tempfile
 import unittest

-from huggingface_hub import delete_repo, login
+from huggingface_hub import Repository, delete_repo, login
 from requests.exceptions import HTTPError
-from transformers import BertConfig, GPT2Config, is_torch_available
+from transformers import AutoConfig, BertConfig, GPT2Config, is_torch_available
 from transformers.configuration_utils import PretrainedConfig
 from transformers.testing_utils import PASS, USER, is_staging_test
@@ -190,6 +190,23 @@ class ConfigTester(object):
         self.check_config_arguments_init()


+class FakeConfig(PretrainedConfig):
+    def __init__(self, attribute=1, **kwargs):
+        self.attribute = attribute
+        super().__init__(**kwargs)
+
+
+# Make sure this is synchronized with the config above.
+FAKE_CONFIG_CODE = """
+from transformers import PretrainedConfig
+
+class FakeConfig(PretrainedConfig):
+    def __init__(self, attribute=1, **kwargs):
+        self.attribute = attribute
+        super().__init__(**kwargs)
+"""
+
+
 @is_staging_test
 class ConfigPushToHubTester(unittest.TestCase):
     @classmethod
@@ -208,6 +225,11 @@ class ConfigPushToHubTester(unittest.TestCase):
         except HTTPError:
             pass

+        try:
+            delete_repo(token=cls._token, name="test-dynamic-config")
+        except HTTPError:
+            pass
+
     def test_push_to_hub(self):
         config = BertConfig(
             vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
@@ -238,6 +260,23 @@ class ConfigPushToHubTester(unittest.TestCase):
             if k != "transformers_version":
                 self.assertEqual(v, getattr(new_config, k))

+    def test_push_to_hub_dynamic_config(self):
+        config = FakeConfig(attribute=42)
+        config.auto_map = {"AutoConfig": "configuration.FakeConfig"}
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            repo = Repository(tmp_dir, clone_from=f"{USER}/test-dynamic-config", use_auth_token=self._token)
+            config.save_pretrained(tmp_dir)
+            with open(os.path.join(tmp_dir, "configuration.py"), "w") as f:
+                f.write(FAKE_CONFIG_CODE)
+
+            repo.push_to_hub()
+
+        new_config = AutoConfig.from_pretrained(f"{USER}/test-dynamic-config", trust_remote_code=True)
+        # Can't make an isinstance check because the new_config is from the FakeConfig class of a dynamic module
+        self.assertEqual(new_config.__class__.__name__, "FakeConfig")
+        self.assertEqual(new_config.attribute, 42)
+

 class ConfigTestUtils(unittest.TestCase):
     def test_config_from_string(self):
@@ -30,7 +30,14 @@ import numpy as np
 import transformers
 from huggingface_hub import Repository, delete_repo, login
 from requests.exceptions import HTTPError
-from transformers import AutoModel, AutoModelForSequenceClassification, is_torch_available, logging
+from transformers import (
+    AutoConfig,
+    AutoModel,
+    AutoModelForSequenceClassification,
+    PretrainedConfig,
+    is_torch_available,
+    logging,
+)
 from transformers.file_utils import WEIGHTS_NAME, is_flax_available, is_torch_fx_available
 from transformers.models.auto import get_values
 from transformers.testing_utils import (
@@ -67,7 +74,6 @@ if is_torch_available():
         AdaptiveEmbedding,
         BertConfig,
         BertModel,
-        PretrainedConfig,
         PreTrainedModel,
         T5Config,
         T5ForConditionalGeneration,
@@ -2078,6 +2084,23 @@ class ModelUtilsTest(TestCasePlus):
         self.assertEqual(model.dtype, torch.float16)


+class FakeConfig(PretrainedConfig):
+    def __init__(self, attribute=1, **kwargs):
+        self.attribute = attribute
+        super().__init__(**kwargs)
+
+
+# Make sure this is synchronized with the config above.
+FAKE_CONFIG_CODE = """
+from transformers import PretrainedConfig
+
+class FakeConfig(PretrainedConfig):
+    def __init__(self, attribute=1, **kwargs):
+        self.attribute = attribute
+        super().__init__(**kwargs)
+"""
+
+
 if is_torch_available():

     class FakeModel(PreTrainedModel):
@@ -2140,6 +2163,11 @@ class ModelPushToHubTester(unittest.TestCase):
         except HTTPError:
             pass

+        try:
+            delete_repo(token=cls._token, name="test-dynamic-model-config")
+        except HTTPError:
+            pass
+
     def test_push_to_hub(self):
         config = BertConfig(
             vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
@@ -2185,5 +2213,47 @@ class ModelPushToHubTester(unittest.TestCase):
             repo.push_to_hub()

         new_model = AutoModel.from_pretrained(f"{USER}/test-dynamic-model", trust_remote_code=True)
         # Can't make an isinstance check because the new_model is from the FakeModel class of a dynamic module
         self.assertEqual(new_model.__class__.__name__, "FakeModel")
         for p1, p2 in zip(model.parameters(), new_model.parameters()):
             self.assertTrue(torch.equal(p1, p2))

+        config = AutoConfig.from_pretrained(f"{USER}/test-dynamic-model")
+        new_model = AutoModel.from_config(config, trust_remote_code=True)
+        self.assertEqual(new_model.__class__.__name__, "FakeModel")
+
+    def test_push_to_hub_dynamic_model_and_config(self):
+        config = FakeConfig(
+            attribute=42,
+            vocab_size=99,
+            hidden_size=32,
+            num_hidden_layers=5,
+            num_attention_heads=4,
+            intermediate_size=37,
+        )
+        config.auto_map = {"AutoConfig": "configuration.FakeConfig", "AutoModel": "modeling.FakeModel"}
+        model = FakeModel(config)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            repo = Repository(tmp_dir, clone_from=f"{USER}/test-dynamic-model-config", use_auth_token=self._token)
+            model.save_pretrained(tmp_dir)
+            with open(os.path.join(tmp_dir, "configuration.py"), "w") as f:
+                f.write(FAKE_CONFIG_CODE)
+            with open(os.path.join(tmp_dir, "modeling.py"), "w") as f:
+                f.write(FAKE_MODEL_CODE)
+
+            repo.push_to_hub()
+
+        new_model = AutoModel.from_pretrained(f"{USER}/test-dynamic-model-config", trust_remote_code=True)
+        # Can't make an isinstance check because the new_model.config is from the FakeConfig class of a dynamic module
+        self.assertEqual(new_model.config.__class__.__name__, "FakeConfig")
+        self.assertEqual(new_model.config.attribute, 42)
+
+        # Can't make an isinstance check because the new_model is from the FakeModel class of a dynamic module
+        self.assertEqual(new_model.__class__.__name__, "FakeModel")
+        for p1, p2 in zip(model.parameters(), new_model.parameters()):
+            self.assertTrue(torch.equal(p1, p2))
+
+        config = AutoConfig.from_pretrained(f"{USER}/test-dynamic-model")
+        new_model = AutoModel.from_config(config, trust_remote_code=True)
+        self.assertEqual(new_model.__class__.__name__, "FakeModel")
@@ -27,11 +27,12 @@ from collections import OrderedDict
 from itertools import takewhile
 from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union

-from huggingface_hub import delete_repo, login
+from huggingface_hub import Repository, delete_repo, login
 from requests.exceptions import HTTPError
 from transformers import (
     AlbertTokenizer,
     AlbertTokenizerFast,
+    AutoTokenizer,
     BertTokenizer,
     BertTokenizerFast,
     PreTrainedTokenizer,
@@ -41,6 +42,7 @@ from transformers import (
     Trainer,
     TrainingArguments,
     is_tf_available,
+    is_tokenizers_available,
     is_torch_available,
 )
 from transformers.testing_utils import (
@@ -3513,6 +3515,28 @@ class TokenizerTesterMixin:
                 self.assertIn("tokenizer.json", os.listdir(os.path.join(tmp_dir, "checkpoint")))


+class FakeTokenizer(BertTokenizer):
+    pass
+
+
+if is_tokenizers_available():
+
+    class FakeTokenizerFast(BertTokenizerFast):
+        pass
+
+
+# Make sure this is synchronized with the tokenizers above.
+FAKE_TOKENIZER_CODE = """
+from transformers import BertTokenizer, BertTokenizerFast
+
+class FakeTokenizer(BertTokenizer):
+    pass
+
+class FakeTokenizerFast(BertTokenizerFast):
+    pass
+"""
+
+
 @is_staging_test
 class TokenizerPushToHubTester(unittest.TestCase):
     vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "bla", "blou"]
@@ -3533,6 +3557,11 @@ class TokenizerPushToHubTester(unittest.TestCase):
         except HTTPError:
             pass

+        try:
+            delete_repo(token=cls._token, name="test-dynamic-tokenizer")
+        except HTTPError:
+            pass
+
     def test_push_to_hub(self):
         with tempfile.TemporaryDirectory() as tmp_dir:
             vocab_file = os.path.join(tmp_dir, "vocab.txt")
@@ -3562,6 +3591,48 @@ class TokenizerPushToHubTester(unittest.TestCase):
             new_tokenizer = BertTokenizer.from_pretrained("valid_org/test-tokenizer-org")
             self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)

+    def test_push_to_hub_dynamic_tokenizer(self):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            vocab_file = os.path.join(tmp_dir, "vocab.txt")
+            with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
+                vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
+            tokenizer = FakeTokenizer(vocab_file)
+
+        # No fast custom tokenizer
+        tokenizer._auto_map = ("tokenizer.FakeTokenizer", None)
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            repo = Repository(tmp_dir, clone_from=f"{USER}/test-dynamic-tokenizer", use_auth_token=self._token)
+            print(os.listdir((tmp_dir)))
+            tokenizer.save_pretrained(tmp_dir)
+            with open(os.path.join(tmp_dir, "tokenizer.py"), "w") as f:
+                f.write(FAKE_TOKENIZER_CODE)
+
+            repo.push_to_hub()
+
+        tokenizer = AutoTokenizer.from_pretrained(f"{USER}/test-dynamic-tokenizer", trust_remote_code=True)
+        # Can't make an isinstance check because the new_model.config is from the FakeConfig class of a dynamic module
+        self.assertEqual(tokenizer.__class__.__name__, "FakeTokenizer")
+
+        # Fast and slow custom tokenizer
+        tokenizer._auto_map = ("tokenizer.FakeTokenizer", "tokenizer.FakeTokenizerFast")
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            repo = Repository(tmp_dir, clone_from=f"{USER}/test-dynamic-tokenizer", use_auth_token=self._token)
+            print(os.listdir((tmp_dir)))
+            tokenizer.save_pretrained(tmp_dir)
+            with open(os.path.join(tmp_dir, "tokenizer.py"), "w") as f:
+                f.write(FAKE_TOKENIZER_CODE)
+
+            repo.push_to_hub()
+
+        tokenizer = AutoTokenizer.from_pretrained(f"{USER}/test-dynamic-tokenizer", trust_remote_code=True)
+        # Can't make an isinstance check because the new_model.config is from the FakeConfig class of a dynamic module
+        self.assertEqual(tokenizer.__class__.__name__, "FakeTokenizerFast")
+        tokenizer = AutoTokenizer.from_pretrained(
+            f"{USER}/test-dynamic-tokenizer", use_fast=False, trust_remote_code=True
+        )
+        # Can't make an isinstance check because the new_model.config is from the FakeConfig class of a dynamic module
+        self.assertEqual(tokenizer.__class__.__name__, "FakeTokenizer")
+

 class TrieTest(unittest.TestCase):
     def test_trie(self):