Improve bert-japanese tokenizer handling (#8659)
* Make ci fail
* Try to make tests actually run?
* CI finally failing?
* Fix CI
* Revert "Fix CI"
This reverts commit ca7923be73.
* Ooops wrong one
* one more try
* Ok ok let's move this elsewhere
* Alternative to globals() (#8667)
* Alternative to globals()
* Error is raised later so return None
* Sentencepiece not installed makes some tokenizers None
* Apply Lysandre wisdom
* Slightly clearer comment?
cc @sgugger
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent eec76615f6
commit 0cc5ab1333
.circleci/config.yml
@@ -221,7 +221,7 @@ jobs:
     run_tests_custom_tokenizers:
         working_directory: ~/transformers
         docker:
-            - image: circleci/python:3.6
+            - image: circleci/python:3.7
         environment:
             RUN_CUSTOM_TOKENIZERS: yes
         steps:
src/transformers/models/auto/tokenization_auto.py
@@ -185,8 +185,6 @@ TOKENIZER_MAPPING = OrderedDict(
         (LongformerConfig, (LongformerTokenizer, LongformerTokenizerFast)),
         (BartConfig, (BartTokenizer, BartTokenizerFast)),
         (LongformerConfig, (LongformerTokenizer, LongformerTokenizerFast)),
-        (RobertaConfig, (BertweetTokenizer, None)),
-        (RobertaConfig, (PhobertTokenizer, None)),
         (RobertaConfig, (RobertaTokenizer, RobertaTokenizerFast)),
         (ReformerConfig, (ReformerTokenizer, ReformerTokenizerFast)),
         (ElectraConfig, (ElectraTokenizer, ElectraTokenizerFast)),
@@ -195,7 +193,6 @@ TOKENIZER_MAPPING = OrderedDict(
         (LayoutLMConfig, (LayoutLMTokenizer, LayoutLMTokenizerFast)),
         (DPRConfig, (DPRQuestionEncoderTokenizer, DPRQuestionEncoderTokenizerFast)),
         (SqueezeBertConfig, (SqueezeBertTokenizer, SqueezeBertTokenizerFast)),
-        (BertConfig, (HerbertTokenizer, HerbertTokenizerFast)),
         (BertConfig, (BertTokenizer, BertTokenizerFast)),
         (OpenAIGPTConfig, (OpenAIGPTTokenizer, OpenAIGPTTokenizerFast)),
         (GPT2Config, (GPT2Tokenizer, GPT2TokenizerFast)),
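The rows removed in these two hunks were effectively dead entries: `TOKENIZER_MAPPING` is an `OrderedDict` built from `(config, tokenizers)` pairs, so duplicate config keys keep only the last value, and the Bertweet, Phobert, and Herbert rows were shadowed by the later plain `RobertaConfig` and `BertConfig` rows. Moving those classes to `NO_CONFIG_TOKENIZER` (added in the next hunk) keeps them discoverable by name instead. A minimal sketch of the key-collapse behavior, using string stand-ins rather than the real config and tokenizer classes:

from collections import OrderedDict

# Duplicate keys collapse: only the last value per key survives, which is
# what shadowed the Bertweet/Phobert entries behind the plain Roberta one.
mapping = OrderedDict(
    [
        ("RobertaConfig", ("BertweetTokenizer", None)),
        ("RobertaConfig", ("PhobertTokenizer", None)),
        ("RobertaConfig", ("RobertaTokenizer", "RobertaTokenizerFast")),
    ]
)
assert mapping["RobertaConfig"] == ("RobertaTokenizer", "RobertaTokenizerFast")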
@@ -213,6 +210,16 @@ TOKENIZER_MAPPING = OrderedDict(
     ]
 )
 
+# For tokenizers which are not directly mapped from a config
+NO_CONFIG_TOKENIZER = [
+    BertJapaneseTokenizer,
+    BertweetTokenizer,
+    HerbertTokenizer,
+    HerbertTokenizerFast,
+    PhobertTokenizer,
+]
+
+
 SLOW_TOKENIZER_MAPPING = {
     k: (v[0] if v[0] is not None else v[1])
     for k, v in TOKENIZER_MAPPING.items()
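The `is not None` guard in this comprehension ties in with the "Sentencepiece not installed" commit above: when an optional dependency is missing, the corresponding slot of the `(slow, fast)` pair is imported as `None`, and `SLOW_TOKENIZER_MAPPING` falls back to whichever class exists. A minimal sketch of that fallback, with a hypothetical placeholder class:

# Placeholder standing in for the fast half of a (slow, fast) tokenizer pair;
# the slow slot is None here, as if sentencepiece were not installed.
class FastTok: ...

PAIRS = {"SomeConfig": (None, FastTok)}

# Same expression as in the diff: prefer the slow class, else take the fast one.
SLOW = {k: (v[0] if v[0] is not None else v[1]) for k, v in PAIRS.items()}
assert SLOW["SomeConfig"] is FastTok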
@@ -220,6 +227,17 @@ SLOW_TOKENIZER_MAPPING = {
 }
 
 
+def tokenizer_class_from_name(class_name: str):
+    all_tokenizer_classes = (
+        [v[0] for v in TOKENIZER_MAPPING.values() if v[0] is not None]
+        + [v[1] for v in TOKENIZER_MAPPING.values() if v[1] is not None]
+        + NO_CONFIG_TOKENIZER
+    )
+    for c in all_tokenizer_classes:
+        if c.__name__ == class_name:
+            return c
+
+
 class AutoTokenizer:
     r"""
     This is a generic tokenizer class that will be instantiated as one of the tokenizer classes of the library when
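`tokenizer_class_from_name` is the "Alternative to globals()" from #8667: it scans the slow classes, the fast classes, and `NO_CONFIG_TOKENIZER`, and simply falls off the end (returning `None` implicitly) when nothing matches, since the error is raised later by the caller. A hedged usage sketch, assuming a transformers checkout at this commit with the tokenizer dependencies installed:

from transformers.models.auto.tokenization_auto import tokenizer_class_from_name

# Found via NO_CONFIG_TOKENIZER even though no config maps to it directly.
assert tokenizer_class_from_name("BertJapaneseTokenizer").__name__ == "BertJapaneseTokenizer"

# Unknown names yield None; AutoTokenizer.from_pretrained raises afterwards.
assert tokenizer_class_from_name("NoSuchTokenizer") is None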
@@ -307,17 +325,17 @@ class AutoTokenizer:
         if not isinstance(config, PretrainedConfig):
             config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
 
-        if "bert-base-japanese" in str(pretrained_model_name_or_path):
-            return BertJapaneseTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
-
         use_fast = kwargs.pop("use_fast", True)
 
         if config.tokenizer_class is not None:
+            tokenizer_class = None
             if use_fast and not config.tokenizer_class.endswith("Fast"):
                 tokenizer_class_candidate = f"{config.tokenizer_class}Fast"
-            else:
+                tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)
+            if tokenizer_class is None:
                 tokenizer_class_candidate = config.tokenizer_class
-            tokenizer_class = globals().get(tokenizer_class_candidate)
+                tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)
+
             if tokenizer_class is None:
                 raise ValueError(
                     "Tokenizer class {} does not exist or is not currently imported.".format(tokenizer_class_candidate)
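With the name-sniffing special case deleted, bert-japanese checkpoints are expected to resolve through `config.tokenizer_class`: the hosted `cl-tohoku/bert-base-japanese` config names `BertJapaneseTokenizer`, the `BertJapaneseTokenizerFast` candidate is not found, and the lookup falls back to the slow class. A sketch of the intended behavior (assumes network access and the Japanese tokenizer dependencies such as fugashi):

from transformers import AutoTokenizer
from transformers.models.bert_japanese.tokenization_bert_japanese import BertJapaneseTokenizer

# config.tokenizer_class == "BertJapaneseTokenizer"; the "...Fast" candidate
# misses, so tokenizer_class_from_name returns the slow class instead.
tokenizer = AutoTokenizer.from_pretrained("cl-tohoku/bert-base-japanese")
assert isinstance(tokenizer, BertJapaneseTokenizer)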
tests/test_tokenization_bert_japanese.py
@@ -18,6 +18,7 @@ import os
 import pickle
 import unittest
 
+from transformers import AutoTokenizer
 from transformers.models.bert_japanese.tokenization_bert_japanese import (
     VOCAB_FILES_NAMES,
     BertJapaneseTokenizer,
@@ -267,3 +268,11 @@ class BertJapaneseCharacterTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         # 2 is for "[CLS]", 3 is for "[SEP]"
         assert encoded_sentence == [2] + text + [3]
         assert encoded_pair == [2] + text + [3] + text_2 + [3]
+
+
+@custom_tokenizers
+class AutoTokenizerCustomTest(unittest.TestCase):
+    def test_tokenizer_bert_japanese(self):
+        EXAMPLE_BERT_JAPANESE_ID = "cl-tohoku/bert-base-japanese"
+        tokenizer = AutoTokenizer.from_pretrained(EXAMPLE_BERT_JAPANESE_ID)
+        self.assertIsInstance(tokenizer, BertJapaneseTokenizer)
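The new test is gated by `@custom_tokenizers`, which skips unless the `RUN_CUSTOM_TOKENIZERS` environment variable is set, the same flag the CI job above exports. One way to run it locally (a sketch; the flag must be set before `transformers.testing_utils` is imported, since it is read at import time):

import os

os.environ["RUN_CUSTOM_TOKENIZERS"] = "yes"  # enable @custom_tokenizers tests

import pytest

# Run only the new AutoTokenizer test, from the repository root.
pytest.main(["tests/test_tokenization_bert_japanese.py", "-k", "AutoTokenizerCustomTest", "-v"])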