This commit is contained in:
thomwolf 2019-04-03 10:51:03 +02:00
parent 846b1fd6f8
commit 1d8c232324

View File

@ -105,13 +105,13 @@ class BertTokenizer(object):
self.max_len = max_len if max_len is not None else int(1e12) self.max_len = max_len if max_len is not None else int(1e12)
def tokenize(self, text):
    """Split ``text`` into word-piece sub-tokens.

    When ``self.do_basic_tokenize`` is truthy, ``text`` is first split by
    ``self.basic_tokenizer`` and every resulting token is then passed to
    ``self.wordpiece_tokenizer``; the sub-tokens are collected, in order,
    into one flat list. Otherwise ``text`` is handed to
    ``self.wordpiece_tokenizer`` directly and its result is returned as-is.

    Args:
        text: The input to tokenize.

    Returns:
        The sequence of sub-tokens (a new list on the basic-tokenize path).
    """
    if not self.do_basic_tokenize:
        # No pre-splitting requested: the word-piece tokenizer sees the raw text.
        return self.wordpiece_tokenizer.tokenize(text)
    # Flatten the per-token word-piece outputs, preserving token order.
    return [
        sub_token
        for token in self.basic_tokenizer.tokenize(text)
        for sub_token in self.wordpiece_tokenizer.tokenize(token)
    ]
def convert_tokens_to_ids(self, tokens): def convert_tokens_to_ids(self, tokens):
@ -142,6 +142,16 @@ class BertTokenizer(object):
""" """
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
if '-cased' in pretrained_model_name_or_path and kwargs.get('do_lower_case', True):
logger.warning("The pre-trained model you are loading is a cased model but you have not set "
"`do_lower_case` to False. We are setting `do_lower_case=False` for you but "
"you may want to check this behavior.")
kwargs['do_lower_case'] = False
elif '-cased' not in pretrained_model_name_or_path and not kwargs.get('do_lower_case', True):
logger.warning("The pre-trained model you are loading is an uncased model but you have set "
"`do_lower_case` to False. We are setting `do_lower_case=True` for you "
"but you may want to check this behavior.")
kwargs['do_lower_case'] = True
else: else:
vocab_file = pretrained_model_name_or_path vocab_file = pretrained_model_name_or_path
if os.path.isdir(vocab_file): if os.path.isdir(vocab_file):