diff --git a/src/transformers/tokenization_bert_japanese.py b/src/transformers/tokenization_bert_japanese.py
index aaf82c54b32..d370e0dced0 100644
--- a/src/transformers/tokenization_bert_japanese.py
+++ b/src/transformers/tokenization_bert_japanese.py
@@ -89,6 +89,7 @@ class BertJapaneseTokenizer(BertTokenizer):
         pad_token="[PAD]",
         cls_token="[CLS]",
         mask_token="[MASK]",
+        mecab_kwargs=None,
         **kwargs
     ):
         """Constructs a MecabBertTokenizer.
@@ -106,6 +107,7 @@ class BertJapaneseTokenizer(BertTokenizer):
                 Type of word tokenizer.
             **subword_tokenizer_type**: (`optional`) string (default "wordpiece")
                 Type of subword tokenizer.
+            **mecab_kwargs**: (`optional`) dict passed to `MecabTokenizer` constructor (default None)
         """
         super(BertTokenizer, self).__init__(
             unk_token=unk_token,
@@ -134,7 +136,9 @@ class BertJapaneseTokenizer(BertTokenizer):
                     do_lower_case=do_lower_case, never_split=never_split, tokenize_chinese_chars=False
                 )
             elif word_tokenizer_type == "mecab":
-                self.word_tokenizer = MecabTokenizer(do_lower_case=do_lower_case, never_split=never_split)
+                self.word_tokenizer = MecabTokenizer(
+                    do_lower_case=do_lower_case, never_split=never_split, **(mecab_kwargs or {})
+                )
             else:
                 raise ValueError("Invalid word_tokenizer_type '{}' is specified.".format(word_tokenizer_type))
 
@@ -164,7 +168,7 @@ class BertJapaneseTokenizer(BertTokenizer):
 class MecabTokenizer(object):
     """Runs basic tokenization with MeCab morphological parser."""
 
-    def __init__(self, do_lower_case=False, never_split=None, normalize_text=True):
+    def __init__(self, do_lower_case=False, never_split=None, normalize_text=True, mecab_option=None):
         """Constructs a MecabTokenizer.
 
         Args:
@@ -176,6 +180,7 @@ class MecabTokenizer(object):
                 List of token not to split.
             **normalize_text**: (`optional`) boolean (default True)
                 Whether to apply unicode normalization to text before tokenization.
+            **mecab_option**: (`optional`) string passed to `MeCab.Tagger` constructor (default "")
         """
         self.do_lower_case = do_lower_case
         self.never_split = never_split if never_split is not None else []
@@ -183,7 +188,7 @@ class MecabTokenizer(object):
 
         import MeCab
 
-        self.mecab = MeCab.Tagger()
+        self.mecab = MeCab.Tagger(mecab_option) if mecab_option is not None else MeCab.Tagger()
 
     def tokenize(self, text, never_split=None, **kwargs):
         """Tokenizes a piece of text."""
diff --git a/tests/test_tokenization_bert_japanese.py b/tests/test_tokenization_bert_japanese.py
index 4900ff49da5..4e0925d7296 100644
--- a/tests/test_tokenization_bert_japanese.py
+++ b/tests/test_tokenization_bert_japanese.py
@@ -91,6 +91,20 @@ class BertJapaneseTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
             ["アップルストア", "で", "iphone", "8", "が", "発売", "さ", "れ", "た", "。"],
         )
 
+    def test_mecab_tokenizer_with_option(self):
+        try:
+            tokenizer = MecabTokenizer(
+                do_lower_case=True, normalize_text=False, mecab_option="-d /usr/local/lib/mecab/dic/jumandic"
+            )
+        except RuntimeError:
+            # if dict doesn't exist in the system, previous code raises this error.
+            return
+
+        self.assertListEqual(
+            tokenizer.tokenize(" \tアップルストアでiPhone8 が \n 発売された　。 "),
+            ["アップルストア", "で", "iPhone", "8", "が", "発売", "さ", "れた", "\u3000", "。"],
+        )
+
     def test_mecab_tokenizer_no_normalize(self):
         tokenizer = MecabTokenizer(normalize_text=False)
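
Usage sketch (not part of the diff): a minimal illustration of the two new parameters, assuming MeCab with JumanDIC installed at the path used in the test above and a local wordpiece vocabulary; "vocab.txt" and the dictionary path are placeholders to adjust for your system.

from transformers.tokenization_bert_japanese import BertJapaneseTokenizer, MecabTokenizer

# Pass a raw option string straight to MeCab.Tagger via the new `mecab_option` argument.
# The -d path must point to an installed dictionary, otherwise MeCab raises a RuntimeError.
word_tokenizer = MecabTokenizer(mecab_option="-d /usr/local/lib/mecab/dic/jumandic")
print(word_tokenizer.tokenize("アップルストアでiPhone8が発売された。"))

# Forward the same option through BertJapaneseTokenizer with the new `mecab_kwargs` dict.
# "vocab.txt" is a placeholder path to a one-wordpiece-per-line vocabulary file.
tokenizer = BertJapaneseTokenizer(
    "vocab.txt",
    word_tokenizer_type="mecab",
    mecab_kwargs={"mecab_option": "-d /usr/local/lib/mecab/dic/jumandic"},
)
print(tokenizer.tokenize("アップルストアでiPhone8が発売された。"))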