Improve bert-japanese tokenizer handling (#8659)

* Make ci fail

* Try to make tests actually run?

* CI finally failing?

* Fix CI

* Revert "Fix CI"

This reverts commit ca7923be73.

* Ooops wrong one

* one more try

* Ok ok let's move this elsewhere

* Alternative to globals() (#8667)

* Alternative to globals()

* Error is raised later so return None

* Sentencepiece not installed make some tokenizers None

* Apply Lysandre wisdom

* Slightly clearer comment?

cc @sgugger

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Julien Chaumond 2020-11-23 17:15:02 +01:00 committed by GitHub
parent eec76615f6
commit 0cc5ab1333
3 changed files with 36 additions and 9 deletions
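
The squash bullets "Alternative to globals()", "Error is raised later so return None", and "Sentencepiece not installed make some tokenizers None" describe the core mechanism of this change. A minimal, self-contained sketch of that lookup idea, using illustrative names only (the real lookup lives in the tokenization_auto.py hunks below):

# Illustrative sketch only: resolve a tokenizer class by name from an explicit registry
# instead of globals(). An entry can be None when an optional dependency (e.g. sentencepiece)
# is not installed, so the lookup filters Nones and itself returns None; the caller raises later.
SomeSlowTokenizer = type("SomeSlowTokenizer", (), {})

MAPPING = {
    "config_a": (SomeSlowTokenizer, None),  # slow class available, no fast variant
    "config_b": (None, None),               # unavailable because its dependency is missing
}

def class_from_name(name):
    candidates = [v[0] for v in MAPPING.values() if v[0] is not None] + [
        v[1] for v in MAPPING.values() if v[1] is not None
    ]
    for c in candidates:
        if c.__name__ == name:
            return c
    return None  # "Error is raised later" by the caller, which knows what it was looking for

assert class_from_name("SomeSlowTokenizer") is SomeSlowTokenizer
assert class_from_name("SomethingNotImported") is None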

.circleci/config.yml

@@ -221,7 +221,7 @@ jobs:
     run_tests_custom_tokenizers:
         working_directory: ~/transformers
         docker:
-            - image: circleci/python:3.6
+            - image: circleci/python:3.7
         environment:
             RUN_CUSTOM_TOKENIZERS: yes
         steps:
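
The RUN_CUSTOM_TOKENIZERS environment variable set by this job is what allows the @custom_tokenizers-gated test added below to run on CI. A rough sketch of how such a gate can be implemented, assuming the real decorator lives in transformers.testing_utils and reads this variable:

# Sketch of an environment-gated test decorator (assumed behaviour, not the library's exact code).
import os
import unittest

_run_custom_tokenizers = os.getenv("RUN_CUSTOM_TOKENIZERS", "false").lower() in ("1", "true", "yes")

def custom_tokenizers(test_case):
    """Skip the decorated test class or method unless RUN_CUSTOM_TOKENIZERS is set."""
    return unittest.skipUnless(_run_custom_tokenizers, "test requires RUN_CUSTOM_TOKENIZERS=yes")(test_case)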

src/transformers/models/auto/tokenization_auto.py

@@ -185,8 +185,6 @@ TOKENIZER_MAPPING = OrderedDict(
         (LongformerConfig, (LongformerTokenizer, LongformerTokenizerFast)),
         (BartConfig, (BartTokenizer, BartTokenizerFast)),
         (LongformerConfig, (LongformerTokenizer, LongformerTokenizerFast)),
-        (RobertaConfig, (BertweetTokenizer, None)),
-        (RobertaConfig, (PhobertTokenizer, None)),
         (RobertaConfig, (RobertaTokenizer, RobertaTokenizerFast)),
         (ReformerConfig, (ReformerTokenizer, ReformerTokenizerFast)),
         (ElectraConfig, (ElectraTokenizer, ElectraTokenizerFast)),
@@ -195,7 +193,6 @@ TOKENIZER_MAPPING = OrderedDict(
         (LayoutLMConfig, (LayoutLMTokenizer, LayoutLMTokenizerFast)),
         (DPRConfig, (DPRQuestionEncoderTokenizer, DPRQuestionEncoderTokenizerFast)),
         (SqueezeBertConfig, (SqueezeBertTokenizer, SqueezeBertTokenizerFast)),
-        (BertConfig, (HerbertTokenizer, HerbertTokenizerFast)),
         (BertConfig, (BertTokenizer, BertTokenizerFast)),
         (OpenAIGPTConfig, (OpenAIGPTTokenizer, OpenAIGPTTokenizerFast)),
         (GPT2Config, (GPT2Tokenizer, GPT2TokenizerFast)),
@@ -213,6 +210,16 @@ TOKENIZER_MAPPING = OrderedDict(
     ]
 )
 
+
+# For tokenizers which are not directly mapped from a config
+NO_CONFIG_TOKENIZER = [
+    BertJapaneseTokenizer,
+    BertweetTokenizer,
+    HerbertTokenizer,
+    HerbertTokenizerFast,
+    PhobertTokenizer,
+]
+
 SLOW_TOKENIZER_MAPPING = {
     k: (v[0] if v[0] is not None else v[1])
     for k, v in TOKENIZER_MAPPING.items()
@@ -220,6 +227,17 @@ SLOW_TOKENIZER_MAPPING = {
 }
 
 
+def tokenizer_class_from_name(class_name: str):
+    all_tokenizer_classes = (
+        [v[0] for v in TOKENIZER_MAPPING.values() if v[0] is not None]
+        + [v[1] for v in TOKENIZER_MAPPING.values() if v[1] is not None]
+        + NO_CONFIG_TOKENIZER
+    )
+    for c in all_tokenizer_classes:
+        if c.__name__ == class_name:
+            return c
+
+
 class AutoTokenizer:
     r"""
     This is a generic tokenizer class that will be instantiated as one of the tokenizer classes of the library when
@@ -307,17 +325,17 @@ class AutoTokenizer:
         if not isinstance(config, PretrainedConfig):
             config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
 
-        if "bert-base-japanese" in str(pretrained_model_name_or_path):
-            return BertJapaneseTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
-
         use_fast = kwargs.pop("use_fast", True)
 
         if config.tokenizer_class is not None:
+            tokenizer_class = None
             if use_fast and not config.tokenizer_class.endswith("Fast"):
                 tokenizer_class_candidate = f"{config.tokenizer_class}Fast"
-            else:
+                tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)
+            if tokenizer_class is None:
                 tokenizer_class_candidate = config.tokenizer_class
-            tokenizer_class = globals().get(tokenizer_class_candidate)
+                tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)
+
             if tokenizer_class is None:
                 raise ValueError(
                     "Tokenizer class {} does not exist or is not currently imported.".format(tokenizer_class_candidate)

tests/test_tokenization_bert_japanese.py

@@ -18,6 +18,7 @@ import os
 import pickle
 import unittest
 
+from transformers import AutoTokenizer
 from transformers.models.bert_japanese.tokenization_bert_japanese import (
     VOCAB_FILES_NAMES,
     BertJapaneseTokenizer,
@@ -267,3 +268,11 @@ class BertJapaneseCharacterTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         # 2 is for "[CLS]", 3 is for "[SEP]"
         assert encoded_sentence == [2] + text + [3]
         assert encoded_pair == [2] + text + [3] + text_2 + [3]
+
+
+@custom_tokenizers
+class AutoTokenizerCustomTest(unittest.TestCase):
+    def test_tokenizer_bert_japanese(self):
+        EXAMPLE_BERT_JAPANESE_ID = "cl-tohoku/bert-base-japanese"
+        tokenizer = AutoTokenizer.from_pretrained(EXAMPLE_BERT_JAPANESE_ID)
+        self.assertIsInstance(tokenizer, BertJapaneseTokenizer)
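
To exercise the new gated test locally, the flag the CircleCI job sets has to be in the environment before transformers.testing_utils is imported. One way to do that, assuming pytest and the Japanese tokenizer extras are installed and the test file path is as in this commit:

# Run only the new AutoTokenizer test with the custom-tokenizers flag enabled (sketch).
import os

os.environ["RUN_CUSTOM_TOKENIZERS"] = "yes"  # must be set before transformers.testing_utils is imported

import pytest

raise SystemExit(pytest.main(["tests/test_tokenization_bert_japanese.py", "-k", "AutoTokenizerCustomTest", "-q"]))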