diff --git a/pytorch_pretrained_bert/tokenization.py b/pytorch_pretrained_bert/tokenization.py
index c549e06d78e..bbb3e25fc79 100644
--- a/pytorch_pretrained_bert/tokenization.py
+++ b/pytorch_pretrained_bert/tokenization.py
@@ -105,13 +105,13 @@ class BertTokenizer(object):
         self.max_len = max_len if max_len is not None else int(1e12)
 
     def tokenize(self, text):
+        split_tokens = []
         if self.do_basic_tokenize:
-          split_tokens = []
-          for token in self.basic_tokenizer.tokenize(text):
-            for sub_token in self.wordpiece_tokenizer.tokenize(token):
-              split_tokens.append(sub_token)
+            for token in self.basic_tokenizer.tokenize(text):
+                for sub_token in self.wordpiece_tokenizer.tokenize(token):
+                    split_tokens.append(sub_token)
         else:
-          split_tokens = self.wordpiece_tokenizer.tokenize(text)
+            split_tokens = self.wordpiece_tokenizer.tokenize(text)
         return split_tokens
 
     def convert_tokens_to_ids(self, tokens):
@@ -142,6 +142,16 @@ class BertTokenizer(object):
         """
         if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
             vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
+            if '-cased' in pretrained_model_name_or_path and kwargs.get('do_lower_case', True):
+                logger.warning("The pre-trained model you are loading is a cased model but you have not set "
+                               "`do_lower_case` to False. We are setting `do_lower_case=False` for you but "
+                               "you may want to check this behavior.")
+                kwargs['do_lower_case'] = False
+            elif '-cased' not in pretrained_model_name_or_path and not kwargs.get('do_lower_case', True):
+                logger.warning("The pre-trained model you are loading is an uncased model but you have set "
+                               "`do_lower_case` to False. We are setting `do_lower_case=True` for you "
+                               "but you may want to check this behavior.")
+                kwargs['do_lower_case'] = True
         else:
             vocab_file = pretrained_model_name_or_path
         if os.path.isdir(vocab_file):
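
A minimal usage sketch of the new `do_lower_case` override (not part of the patch, assuming the package is importable as `pytorch_pretrained_bert` and the tokenizer exposes its `basic_tokenizer` attribute as in this file):

```python
from pytorch_pretrained_bert import BertTokenizer

# Cased checkpoint requested without do_lower_case=False:
# with this patch, the warning above is logged and do_lower_case is forced to False.
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
print(tokenizer.basic_tokenizer.do_lower_case)  # False after the override

# Symmetric case: uncased checkpoint with do_lower_case=False passed explicitly;
# the warning fires and the flag is overridden back to True.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=False)
print(tokenizer.basic_tokenizer.do_lower_case)  # True after the override
```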