diff --git a/.circleci/config.yml b/.circleci/config.yml
index 0fa61008cee..aedbe4f5539 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -221,7 +221,7 @@ jobs:
     run_tests_custom_tokenizers:
         working_directory: ~/transformers
         docker:
-            - image: circleci/python:3.6
+            - image: circleci/python:3.7
         environment:
             RUN_CUSTOM_TOKENIZERS: yes
         steps:
diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py
index 5619b773339..a7e05d92e4e 100644
--- a/src/transformers/models/auto/tokenization_auto.py
+++ b/src/transformers/models/auto/tokenization_auto.py
@@ -185,8 +185,6 @@ TOKENIZER_MAPPING = OrderedDict(
         (LongformerConfig, (LongformerTokenizer, LongformerTokenizerFast)),
         (BartConfig, (BartTokenizer, BartTokenizerFast)),
         (LongformerConfig, (LongformerTokenizer, LongformerTokenizerFast)),
-        (RobertaConfig, (BertweetTokenizer, None)),
-        (RobertaConfig, (PhobertTokenizer, None)),
         (RobertaConfig, (RobertaTokenizer, RobertaTokenizerFast)),
         (ReformerConfig, (ReformerTokenizer, ReformerTokenizerFast)),
         (ElectraConfig, (ElectraTokenizer, ElectraTokenizerFast)),
@@ -195,7 +193,6 @@ TOKENIZER_MAPPING = OrderedDict(
         (LayoutLMConfig, (LayoutLMTokenizer, LayoutLMTokenizerFast)),
         (DPRConfig, (DPRQuestionEncoderTokenizer, DPRQuestionEncoderTokenizerFast)),
         (SqueezeBertConfig, (SqueezeBertTokenizer, SqueezeBertTokenizerFast)),
-        (BertConfig, (HerbertTokenizer, HerbertTokenizerFast)),
         (BertConfig, (BertTokenizer, BertTokenizerFast)),
         (OpenAIGPTConfig, (OpenAIGPTTokenizer, OpenAIGPTTokenizerFast)),
         (GPT2Config, (GPT2Tokenizer, GPT2TokenizerFast)),
@@ -213,6 +210,16 @@ TOKENIZER_MAPPING = OrderedDict(
     ]
 )
 
+# For tokenizers which are not directly mapped from a config
+NO_CONFIG_TOKENIZER = [
+    BertJapaneseTokenizer,
+    BertweetTokenizer,
+    HerbertTokenizer,
+    HerbertTokenizerFast,
+    PhobertTokenizer,
+]
+
+
 SLOW_TOKENIZER_MAPPING = {
     k: (v[0] if v[0] is not None else v[1])
     for k, v in TOKENIZER_MAPPING.items()
@@ -220,6 +227,17 @@ SLOW_TOKENIZER_MAPPING = {
 }
 
 
+def tokenizer_class_from_name(class_name: str):
+    all_tokenizer_classes = (
+        [v[0] for v in TOKENIZER_MAPPING.values() if v[0] is not None]
+        + [v[1] for v in TOKENIZER_MAPPING.values() if v[1] is not None]
+        + NO_CONFIG_TOKENIZER
+    )
+    for c in all_tokenizer_classes:
+        if c.__name__ == class_name:
+            return c
+
+
 class AutoTokenizer:
     r"""
     This is a generic tokenizer class that will be instantiated as one of the tokenizer classes of the library when
@@ -307,17 +325,17 @@ class AutoTokenizer:
         if not isinstance(config, PretrainedConfig):
             config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
 
-        if "bert-base-japanese" in str(pretrained_model_name_or_path):
-            return BertJapaneseTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
-
         use_fast = kwargs.pop("use_fast", True)
 
         if config.tokenizer_class is not None:
+            tokenizer_class = None
             if use_fast and not config.tokenizer_class.endswith("Fast"):
                 tokenizer_class_candidate = f"{config.tokenizer_class}Fast"
-            else:
+                tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)
+            if tokenizer_class is None:
                 tokenizer_class_candidate = config.tokenizer_class
-            tokenizer_class = globals().get(tokenizer_class_candidate)
+                tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)
+
             if tokenizer_class is None:
                 raise ValueError(
                     "Tokenizer class {} does not exist or is not currently imported.".format(tokenizer_class_candidate)
diff --git a/tests/test_tokenization_bert_japanese.py b/tests/test_tokenization_bert_japanese.py
index 55ae6f41c47..1424652427e 100644
--- a/tests/test_tokenization_bert_japanese.py
+++ b/tests/test_tokenization_bert_japanese.py
@@ -18,6 +18,7 @@ import os
 import pickle
 import unittest
 
+from transformers import AutoTokenizer
 from transformers.models.bert_japanese.tokenization_bert_japanese import (
     VOCAB_FILES_NAMES,
     BertJapaneseTokenizer,
@@ -267,3 +268,11 @@ class BertJapaneseCharacterTokenizationTest(TokenizerTesterMixin, unittest.TestC
         # 2 is for "[CLS]", 3 is for "[SEP]"
         assert encoded_sentence == [2] + text + [3]
         assert encoded_pair == [2] + text + [3] + text_2 + [3]
+
+
+@custom_tokenizers
+class AutoTokenizerCustomTest(unittest.TestCase):
+    def test_tokenizer_bert_japanese(self):
+        EXAMPLE_BERT_JAPANESE_ID = "cl-tohoku/bert-base-japanese"
+        tokenizer = AutoTokenizer.from_pretrained(EXAMPLE_BERT_JAPANESE_ID)
+        self.assertIsInstance(tokenizer, BertJapaneseTokenizer)
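A note on the behavior change, outside the patch itself: the hard-coded `"bert-base-japanese"` substring check is replaced by a generic lookup through `config.tokenizer_class`, using the new `tokenizer_class_from_name` helper. Below is a minimal sketch of the new resolution path, not part of the diff; it assumes the checkpoint's config sets `tokenizer_class` (the checkpoint id comes from the test added above, everything else is from the patch):

```python
from transformers import AutoTokenizer
from transformers.models.auto.tokenization_auto import tokenizer_class_from_name

# The helper resolves a class-name string to a tokenizer class, searching the
# config-mapped classes as well as the new NO_CONFIG_TOKENIZER list.
cls = tokenizer_class_from_name("BertJapaneseTokenizer")
assert cls.__name__ == "BertJapaneseTokenizer"

# An unknown name falls through the loop without a `return`, yielding None;
# AutoTokenizer.from_pretrained turns that into the "does not exist or is not
# currently imported" ValueError.
assert tokenizer_class_from_name("NoSuchTokenizer") is None

# End to end: with use_fast=True the candidate "BertJapaneseTokenizerFast" is
# tried first, resolves to None, and the lookup falls back to the slow class.
tokenizer = AutoTokenizer.from_pretrained("cl-tohoku/bert-base-japanese")
```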