diff --git a/src/transformers/tokenization_t5.py b/src/transformers/tokenization_t5.py
index d23e79e5dba..7a5e7fd587c 100644
--- a/src/transformers/tokenization_t5.py
+++ b/src/transformers/tokenization_t5.py
@@ -249,8 +249,17 @@ class T5Tokenizer(PreTrainedTokenizer):
 
     def convert_tokens_to_string(self, tokens):
         """ Converts a sequence of tokens (string) in a single string. """
-        out_string = self.sp_model.decode_pieces(tokens)
-        return out_string
+        current_sub_tokens = []
+        out_string = ""
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                out_string += self.sp_model.decode_pieces(current_sub_tokens) + token + " "
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+        out_string += self.sp_model.decode_pieces(current_sub_tokens)
+        return out_string.strip()
 
     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
         if not os.path.isdir(save_directory):
diff --git a/tests/test_tokenization_t5.py b/tests/test_tokenization_t5.py
index 5be6fdfdffd..05d45d9b693 100644
--- a/tests/test_tokenization_t5.py
+++ b/tests/test_tokenization_t5.py
@@ -222,3 +222,18 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
 
         self.assertEqual(expected_src_tokens, src_ids)
         self.assertEqual(expected_tgt_tokens, tgt_ids)
+
+    def test_fast_and_slow_same_result(self):
+        src_text = "<pad> Today is <unk> nice day </s>"
+        tgt_ids = [0, 1960, 19, 2, 1245, 239, 1]
+        tgt_text = "<pad> Today is<unk> nice day</s>"
+
+        fast_ids = self.t5_base_tokenizer_fast(src_text, add_special_tokens=False).input_ids
+        slow_ids = self.t5_base_tokenizer(src_text, add_special_tokens=False).input_ids
+        self.assertEqual(tgt_ids, fast_ids)
+        self.assertEqual(tgt_ids, slow_ids)
+
+        fast_text = self.t5_base_tokenizer_fast.decode(fast_ids)
+        slow_text = self.t5_base_tokenizer.decode(fast_ids)
+        self.assertEqual(tgt_text, fast_text)
+        self.assertEqual(tgt_text, slow_text)
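
A minimal usage sketch of the patched slow-tokenizer path (assumes the "t5-base" checkpoint is available; the token list below is hypothetical SentencePiece output written by hand). It shows why the loop above buffers ordinary pieces in current_sub_tokens: special tokens such as "<unk>" and "</s>" are not pieces in the SentencePiece vocabulary, so they are appended verbatim instead of being round-tripped through sp_model.decode_pieces, matching the fast tokenizer's behaviour.

    from transformers import T5Tokenizer

    tokenizer = T5Tokenizer.from_pretrained("t5-base")

    # "<unk>" and "</s>" are in tokenizer.all_special_tokens, so the patched
    # convert_tokens_to_string emits them as-is and only passes the buffered
    # sentencepiece pieces (prefixed with "▁") to sp_model.decode_pieces.
    tokens = ["▁Today", "▁is", "<unk>", "▁nice", "▁day", "</s>"]
    print(tokenizer.convert_tokens_to_string(tokens))
    # -> "Today is<unk> nice day</s>"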