Fixed DistilBERT tokenizer

LysandreJik 2019-09-24 09:41:14 -04:00
parent d340e2329e
commit 72402d1acd
2 changed files with 4 additions and 9 deletions


@@ -39,8 +39,10 @@ class DistilBertTokenizationTest(BertTokenizationTest):
         encoded_sentence = tokenizer.add_special_tokens_single_sequence(text)
         encoded_pair = tokenizer.add_special_tokens_sequence_pair(text, text_2)
-        assert encoded_sentence == text
-        assert encoded_pair == text + [102] + text_2
+        assert encoded_sentence == [tokenizer.cls_token_id] + text + [tokenizer.sep_token_id]
+        assert encoded_pair == [tokenizer.cls_token_id] + text + [tokenizer.sep_token_id] + \
+                               text_2 + [tokenizer.sep_token_id]

 if __name__ == '__main__':
     unittest.main()
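
The updated assertions follow BERT's standard single-sequence and pair patterns. As a quick sanity check, a minimal sketch of the expected encodings, using the usual bert-base-uncased special-token ids (101 = [CLS], 102 = [SEP]) and made-up word-piece ids for illustration (none of these values are taken from the diff):

    # Illustrative sketch only: the word-piece ids below are assumed.
    cls_id, sep_id = 101, 102   # bert-base-uncased [CLS]/[SEP] ids

    text = [7592, 2088]         # hypothetical ids for a first sequence
    text_2 = [2129, 2024]       # hypothetical ids for a second sequence

    encoded_sentence = [cls_id] + text + [sep_id]
    encoded_pair = [cls_id] + text + [sep_id] + text_2 + [sep_id]

    print(encoded_sentence)     # [101, 7592, 2088, 102]
    print(encoded_pair)         # [101, 7592, 2088, 102, 2129, 2024, 102]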


@@ -60,10 +60,3 @@ class DistilBertTokenizer(BertTokenizer):
     vocab_files_names = VOCAB_FILES_NAMES
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
-
-    def add_special_tokens_single_sequence(self, token_ids):
-        return token_ids
-
-    def add_special_tokens_sequence_pair(self, token_ids_0, token_ids_1):
-        sep = [self.sep_token_id]
-        return token_ids_0 + sep + token_ids_1
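
With these overrides deleted, DistilBertTokenizer falls back to BertTokenizer's implementations, which wrap the ids in [CLS]/[SEP] instead of returning them bare (the old single-sequence override) or joining them with a lone [SEP] (the old pair override). A minimal usage sketch of the inherited behavior, assuming the 2019-era method names shown in the diff and the usual distilbert-base-uncased checkpoint (neither the import path nor the checkpoint name appears in the diff):

    from transformers import DistilBertTokenizer

    # Assumed checkpoint name, for illustration only.
    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

    ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize("hello world"))

    # Inherited from BertTokenizer now that the override is gone:
    # [CLS] + ids + [SEP] rather than the bare ids.
    encoded = tokenizer.add_special_tokens_single_sequence(ids)
    assert encoded == [tokenizer.cls_token_id] + ids + [tokenizer.sep_token_id]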