Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-16 19:18:24 +06:00)
Improve bert-japanese tokenizer handling (#8659)
* Make ci fail
* Try to make tests actually run?
* CI finally failing?
* Fix CI
* Revert "Fix CI"
This reverts commit ca7923be73.
* Ooops wrong one
* one more try
* Ok ok let's move this elsewhere
* Alternative to globals() (#8667)
* Alternative to globals()
* Error is raised later so return None
* Sentencepiece not installed makes some tokenizers None
* Apply Lysandre wisdom
* Slightly clearer comment?
cc @sgugger
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
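The "Alternative to globals()" change is the heart of the PR: `globals().get(name)` can only resolve classes that happen to be bound in the `tokenization_auto` module's own namespace, and when sentencepiece is not installed several tokenizer names end up bound to None, so the lookup silently misbehaves. A minimal standalone illustration of the limitation (hypothetical snippet, not library code):

# globals() only sees names bound in the *current* module, so a class that was
# never imported here (or was set to None because an optional dependency is
# missing) cannot be resolved by name.
import collections

print(globals().get("OrderedDict"))   # None: only `collections` is bound here
print(globals().get("collections"))  # <module 'collections' ...>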
This commit is contained in:
parent eec76615f6
commit 0cc5ab1333
@@ -221,7 +221,7 @@ jobs:
     run_tests_custom_tokenizers:
         working_directory: ~/transformers
         docker:
-            - image: circleci/python:3.6
+            - image: circleci/python:3.7
         environment:
             RUN_CUSTOM_TOKENIZERS: yes
         steps:
@@ -185,8 +185,6 @@ TOKENIZER_MAPPING = OrderedDict(
         (LongformerConfig, (LongformerTokenizer, LongformerTokenizerFast)),
         (BartConfig, (BartTokenizer, BartTokenizerFast)),
         (LongformerConfig, (LongformerTokenizer, LongformerTokenizerFast)),
-        (RobertaConfig, (BertweetTokenizer, None)),
-        (RobertaConfig, (PhobertTokenizer, None)),
         (RobertaConfig, (RobertaTokenizer, RobertaTokenizerFast)),
         (ReformerConfig, (ReformerTokenizer, ReformerTokenizerFast)),
         (ElectraConfig, (ElectraTokenizer, ElectraTokenizerFast)),
@@ -195,7 +193,6 @@ TOKENIZER_MAPPING = OrderedDict(
         (LayoutLMConfig, (LayoutLMTokenizer, LayoutLMTokenizerFast)),
         (DPRConfig, (DPRQuestionEncoderTokenizer, DPRQuestionEncoderTokenizerFast)),
         (SqueezeBertConfig, (SqueezeBertTokenizer, SqueezeBertTokenizerFast)),
-        (BertConfig, (HerbertTokenizer, HerbertTokenizerFast)),
         (BertConfig, (BertTokenizer, BertTokenizerFast)),
         (OpenAIGPTConfig, (OpenAIGPTTokenizer, OpenAIGPTTokenizerFast)),
         (GPT2Config, (GPT2Tokenizer, GPT2TokenizerFast)),
@@ -213,6 +210,16 @@ TOKENIZER_MAPPING = OrderedDict(
     ]
 )
 
+# For tokenizers which are not directly mapped from a config
+NO_CONFIG_TOKENIZER = [
+    BertJapaneseTokenizer,
+    BertweetTokenizer,
+    HerbertTokenizer,
+    HerbertTokenizerFast,
+    PhobertTokenizer,
+]
+
+
 SLOW_TOKENIZER_MAPPING = {
     k: (v[0] if v[0] is not None else v[1])
     for k, v in TOKENIZER_MAPPING.items()
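Moving BertweetTokenizer, PhobertTokenizer, and the Herbert tokenizers out of TOKENIZER_MAPPING and into NO_CONFIG_TOKENIZER also fixes a shadowing bug: the mapping is a dict keyed by config class, so several tokenizers sharing one config (here RobertaConfig and BertConfig) cannot coexist in it. A standalone demonstration with placeholder strings (not the real classes):

from collections import OrderedDict

# Duplicate keys collapse: only the last value per key survives, so the
# Bertweet and Phobert rows removed above were dead entries, shadowed by
# the plain Roberta row.
mapping = OrderedDict(
    [
        ("RobertaConfig", "BertweetTokenizer"),
        ("RobertaConfig", "PhobertTokenizer"),
        ("RobertaConfig", "RobertaTokenizer"),
    ]
)
print(mapping["RobertaConfig"])  # RobertaTokenizer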
@@ -220,6 +227,17 @@ SLOW_TOKENIZER_MAPPING = {
 }
 
 
+def tokenizer_class_from_name(class_name: str):
+    all_tokenizer_classes = (
+        [v[0] for v in TOKENIZER_MAPPING.values() if v[0] is not None]
+        + [v[1] for v in TOKENIZER_MAPPING.values() if v[1] is not None]
+        + NO_CONFIG_TOKENIZER
+    )
+    for c in all_tokenizer_classes:
+        if c.__name__ == class_name:
+            return c
+
+
 class AutoTokenizer:
     r"""
     This is a generic tokenizer class that will be instantiated as one of the tokenizer classes of the library when
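A usage sketch for the new helper, assuming this patch is applied (module path per the v4 layout used by the test file below): it searches the slow classes, the fast classes, and NO_CONFIG_TOKENIZER for a matching __name__, and falls through to an implicit None for unknown names or classes whose optional dependency failed to import, which the caller turns into a ValueError.

from transformers.models.auto.tokenization_auto import tokenizer_class_from_name

cls = tokenizer_class_from_name("BertJapaneseTokenizer")
print(cls)                                           # found via NO_CONFIG_TOKENIZER
print(tokenizer_class_from_name("NoSuchTokenizer"))  # None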
@@ -307,17 +325,17 @@ class AutoTokenizer:
         if not isinstance(config, PretrainedConfig):
             config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
 
-        if "bert-base-japanese" in str(pretrained_model_name_or_path):
-            return BertJapaneseTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
-
         use_fast = kwargs.pop("use_fast", True)
 
         if config.tokenizer_class is not None:
+            tokenizer_class = None
             if use_fast and not config.tokenizer_class.endswith("Fast"):
                 tokenizer_class_candidate = f"{config.tokenizer_class}Fast"
-            else:
+                tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)
+            if tokenizer_class is None:
                 tokenizer_class_candidate = config.tokenizer_class
-                tokenizer_class = globals().get(tokenizer_class_candidate)
+                tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)
+
             if tokenizer_class is None:
                 raise ValueError(
                     "Tokenizer class {} does not exist or is not currently imported.".format(tokenizer_class_candidate)
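The control-flow change matters for tokenizers with no fast variant: previously, with use_fast=True (the default) and config.tokenizer_class set to "BertJapaneseTokenizer", the old code committed to the candidate "BertJapaneseTokenizerFast", never fell back, and raised the ValueError. The new version tries the Fast name first and falls back to the slow class. Restated here for illustration only (hedged paraphrase, not additional library code):

def resolve_tokenizer_class(name: str, use_fast: bool = True):
    # Try the fast variant first, e.g. "BertJapaneseTokenizerFast" (which
    # does not exist), then fall back to the slow class by its exact name.
    tokenizer_class = None
    if use_fast and not name.endswith("Fast"):
        tokenizer_class = tokenizer_class_from_name(f"{name}Fast")
    if tokenizer_class is None:
        tokenizer_class = tokenizer_class_from_name(name)
    if tokenizer_class is None:
        raise ValueError(f"Tokenizer class {name} does not exist or is not currently imported.")
    return tokenizer_class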
|
@ -18,6 +18,7 @@ import os
|
|||||||
import pickle
|
import pickle
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
from transformers import AutoTokenizer
|
||||||
from transformers.models.bert_japanese.tokenization_bert_japanese import (
|
from transformers.models.bert_japanese.tokenization_bert_japanese import (
|
||||||
VOCAB_FILES_NAMES,
|
VOCAB_FILES_NAMES,
|
||||||
BertJapaneseTokenizer,
|
BertJapaneseTokenizer,
|
||||||
@@ -267,3 +268,11 @@ class BertJapaneseCharacterTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         # 2 is for "[CLS]", 3 is for "[SEP]"
         assert encoded_sentence == [2] + text + [3]
         assert encoded_pair == [2] + text + [3] + text_2 + [3]
+
+
+@custom_tokenizers
+class AutoTokenizerCustomTest(unittest.TestCase):
+    def test_tokenizer_bert_japanese(self):
+        EXAMPLE_BERT_JAPANESE_ID = "cl-tohoku/bert-base-japanese"
+        tokenizer = AutoTokenizer.from_pretrained(EXAMPLE_BERT_JAPANESE_ID)
+        self.assertIsInstance(tokenizer, BertJapaneseTokenizer)
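The new test doubles as the user-facing contract: AutoTokenizer now reaches BertJapaneseTokenizer through the config's tokenizer_class field rather than the removed hard-coded name check. End to end (a sketch; downloading the checkpoint needs network access, and BertJapaneseTokenizer needs the Japanese extras such as fugashi/MeCab):

from transformers import AutoTokenizer, BertJapaneseTokenizer

tokenizer = AutoTokenizer.from_pretrained("cl-tohoku/bert-base-japanese")
assert isinstance(tokenizer, BertJapaneseTokenizer)  # resolved via config.tokenizer_class

The @custom_tokenizers decorator keeps this test out of the default suite; the CircleCI job above opts in by setting RUN_CUSTOM_TOKENIZERS: yes.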