diff --git a/pytorch_transformers/tests/tokenization_tests_commons.py b/pytorch_transformers/tests/tokenization_tests_commons.py
index e077333c193..4f85e2bdd64 100644
--- a/pytorch_transformers/tests/tokenization_tests_commons.py
+++ b/pytorch_transformers/tests/tokenization_tests_commons.py
@@ -196,7 +196,8 @@ class CommonTestCases:
             if tokenizer.add_special_tokens_sentences_pair.__qualname__.split('.')[0] != "PreTrainedTokenizer":
                 seq_0 = "Test this method."
                 seq_1 = "With these inputs."
-                sequences, mask = tokenizer.encode(seq_0, seq_1, add_special_tokens=True, output_mask=True)
+                information = tokenizer.encode_plus(seq_0, seq_1, add_special_tokens=True, output_mask=True)
+                sequences, mask = information["sequence"], information["mask"]
                 assert len(sequences) == len(mask)
 
         def test_number_of_added_tokens(self):
@@ -210,7 +211,7 @@ class CommonTestCases:
 
             # Method is implemented (e.g. not GPT-2)
             if len(attached_sequences) != 2:
-                assert tokenizer.num_added_tokens(pair=True) == len(attached_sequences) - sum([len(seq) for seq in sequences])
+                assert tokenizer.num_added_tokens(pair=True) == len(attached_sequences) - len(sequences)
 
         def test_maximum_encoding_length_single_input(self):
             tokenizer = self.get_tokenizer()
@@ -220,8 +221,12 @@ class CommonTestCases:
             sequence = tokenizer.encode(seq_0)
             num_added_tokens = tokenizer.num_added_tokens()
             total_length = len(sequence) + num_added_tokens
-            truncated_sequence = tokenizer.encode(seq_0, max_length=total_length - 2, add_special_tokens=True)
+            information = tokenizer.encode_plus(seq_0, max_length=total_length - 2, add_special_tokens=True)
+            truncated_sequence = information["sequence"]
+            overflowing_tokens = information["overflowing_tokens"]
+
+            assert len(overflowing_tokens) == 2
 
             assert len(truncated_sequence) == total_length - 2
             assert truncated_sequence == tokenizer.add_special_tokens_single_sentence(sequence[:-2])
 
@@ -236,7 +241,10 @@ class CommonTestCases:
                 tokenizer.encode(seq_0),
                 tokenizer.encode(seq_1)[:-2]
             )
-            truncated_sequence = tokenizer.encode(seq_0, seq_1, max_length=len(sequence) - 2, add_special_tokens=True)
+            information = tokenizer.encode_plus(seq_0, seq_1, max_length=len(sequence) - 2, add_special_tokens=True)
+
+            truncated_sequence = information["sequence"]
+            overflowing_tokens = information["overflowing_tokens"]
 
             assert len(truncated_sequence) == len(sequence) - 2
             assert truncated_sequence == truncated_second_sequence
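
For reference, a minimal self-contained sketch of the dict contract the updated tests rely on. The key names ("sequence", "mask", "overflowing_tokens") are taken from the diff above; toy_encode_plus below is a stand-in written for illustration, not the library's encode_plus, and the input ids are arbitrary.

# Stand-in for illustration only (not the library method): shows the dict
# shape the updated tests consume -- a truncated id sequence, a matching
# pair mask, and the ids that were dropped by max_length truncation.
def toy_encode_plus(ids_0, ids_1, max_length=None):
    sequence = ids_0 + ids_1
    mask = [0] * len(ids_0) + [1] * len(ids_1)
    overflowing_tokens = []
    if max_length is not None and len(sequence) > max_length:
        overflowing_tokens = sequence[max_length:]  # ids removed by truncation
        sequence = sequence[:max_length]
        mask = mask[:max_length]
    return {"sequence": sequence, "mask": mask, "overflowing_tokens": overflowing_tokens}

info = toy_encode_plus([1, 2, 3], [4, 5, 6], max_length=4)
assert len(info["sequence"]) == len(info["mask"])  # mirrors the mask-length check
assert info["overflowing_tokens"] == [5, 6]        # truncated ids are surfaced, not lost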