mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
Fixed error with encoding
This commit is contained in:
parent
634a3172d8
commit
3d87991f60
@ -81,11 +81,14 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
|
||||
text = tokenizer.encode("sequence builders")
|
||||
text_2 = tokenizer.encode("multi-sequence build")
|
||||
|
||||
encoded_text_from_decode = tokenizer.encode("sequence builders", add_special_tokens=True)
|
||||
encoded_pair_from_decode = tokenizer.encode("sequence builders", "multi-sequence build", add_special_tokens=True)
|
||||
|
||||
encoded_sentence = tokenizer.add_special_tokens_single_sentence(text)
|
||||
encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2)
|
||||
|
||||
assert encoded_sentence == [0] + text + [2]
|
||||
assert encoded_pair == [0] + text + [2, 2] + text_2 + [2]
|
||||
assert encoded_sentence == encoded_text_from_decode
|
||||
assert encoded_pair == encoded_pair_from_decode
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -519,24 +519,19 @@ class PreTrainedTokenizer(object):
|
||||
def _convert_token_to_id(self, token):
|
||||
raise NotImplementedError
|
||||
|
||||
def encode(self, text, add_special_tokens=False, *sequences):
|
||||
def encode(self, text, text_pair=None, add_special_tokens=False):
|
||||
""" Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.
|
||||
|
||||
Same doing ``self.convert_tokens_to_ids(self.tokenize(text))``.
|
||||
"""
|
||||
|
||||
if len(sequences) == 0:
|
||||
if text_pair is None:
|
||||
if add_special_tokens:
|
||||
return self.add_special_tokens_single_sentence(self.convert_tokens_to_ids(self.tokenize(text)))
|
||||
else:
|
||||
return self.convert_tokens_to_ids(self.tokenize(text))
|
||||
|
||||
if len(sequences) > 1:
|
||||
logger.warning("Tokenization currently only supports sentence pairs. Ignoring every string following the "
|
||||
"initial two.")
|
||||
|
||||
first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text)]
|
||||
second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(sequences[0])]
|
||||
second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair)]
|
||||
|
||||
if add_special_tokens:
|
||||
return self.add_special_tokens_sentences_pair(first_sentence_tokens, second_sentence_tokens)
|
||||
|
Loading…
Reference in New Issue
Block a user