add kwargs to base encode function

This commit is contained in:
thomwolf 2019-08-27 14:05:59 +02:00
parent f1b018740c
commit a175a9dc01

View File

@ -563,7 +563,7 @@ class PreTrainedTokenizer(object):
def _convert_token_to_id(self, token):
raise NotImplementedError
def encode(self, text, text_pair=None, add_special_tokens=False):
def encode(self, text, text_pair=None, add_special_tokens=False, **kwargs):
"""
Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.
@ -574,15 +574,16 @@ class PreTrainedTokenizer(object):
text_pair: Optional second sequence to be encoded.
add_special_tokens: if set to ``True``, the sequences will be encoded with the special tokens relative
to their model.
**kwargs: passed to the `self.tokenize()` method
"""
if text_pair is None:
if add_special_tokens:
return self.add_special_tokens_single_sentence(self.convert_tokens_to_ids(self.tokenize(text)))
return self.add_special_tokens_single_sentence(self.convert_tokens_to_ids(self.tokenize(text, **kwargs)))
else:
return self.convert_tokens_to_ids(self.tokenize(text))
return self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text)]
second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair)]
first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text, **kwargs)]
second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair, **kwargs)]
if add_special_tokens:
return self.add_special_tokens_sentences_pair(first_sentence_tokens, second_sentence_tokens)