From 2ba040a71f26987477d56fecf451ce92340ce0af Mon Sep 17 00:00:00 2001 From: Dmitry Tarasov <15786586+mrsndmn@users.noreply.github.com> Date: Tue, 4 Feb 2025 12:27:52 +0300 Subject: [PATCH] apply_chat_template: consistent behaviour for return_assistant_tokens_mask=True return_tensors=True (#35582) * apply_chat_template: consistent return_tensors behaviour with return_assistant_tokens_mask flag * test_chat_template_return_assistant_tokens_mask: support tokenizers with no attention mask * test_chat_template_return_assistant_tokens_mask: skip tokenizers with no padding token * test_chat_template_return_assistant_tokens_mask: force tokenizer padding_side=right --------- Co-authored-by: Eduard Allakhverdov Co-authored-by: d.tarasov --- src/transformers/tokenization_utils_base.py | 10 +++- tests/test_tokenization_common.py | 52 +++++++++++++++++++++ 2 files changed, 61 insertions(+), 1 deletion(-) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 1e0d76ad5c5..7ad36ab017f 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -1742,7 +1742,15 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): for token_id in range(start_token, end_token + 1 if end_token else len(input_ids[i])): current_mask[token_id] = 1 assistant_masks.append(current_mask) - out["assistant_masks"] = assistant_masks if is_batched else assistant_masks[0] + + if not is_batched and not return_tensors: + assistant_masks = assistant_masks[0] + + out["assistant_masks"] = assistant_masks + + if return_tensors: + out.convert_to_tensors(tensor_type=return_tensors) + return out else: return out["input_ids"] diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index 9bf90efd4b5..d1dc9cc2024 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -62,6 +62,7 @@ from transformers.tokenization_utils import AddedToken if is_torch_available(): + import torch import torch.nn as nn @@ -1219,6 +1220,7 @@ class TokenizerTesterMixin: self.assertEqual(len(strftime_output), 10) self.assertEqual(len(strftime_output.split("-")), 3) + @require_torch @require_jinja def test_chat_template_return_assistant_tokens_mask(self): dummy_template = ( @@ -1263,6 +1265,9 @@ class TokenizerTesterMixin: self.skipTest(reason="No fast tokenizer defined") tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name) + self._check_no_pad_token_padding(tokenizer_r, conversations) + + tokenizer_r.padding_side = "right" # check batched output = tokenizer_r.apply_chat_template( @@ -1272,6 +1277,20 @@ class TokenizerTesterMixin: return_assistant_tokens_mask=True, return_dict=True, ) + + output_pt = tokenizer_r.apply_chat_template( + conversations, + chat_template=dummy_template, + tokenize=True, + padding=True, + return_assistant_tokens_mask=True, + return_dict=True, + return_tensors="pt", + ) + + self.assertEqual(type(output_pt["assistant_masks"]), torch.Tensor) + self.assertEqual(output_pt["assistant_masks"].shape, output_pt["input_ids"].shape) + for i, conv in enumerate(conversations): chat_string = tokenizer_r.apply_chat_template( conversations[i], tokenize=False, chat_template=dummy_template @@ -1297,18 +1316,30 @@ class TokenizerTesterMixin: output["assistant_masks"][i][assistant_start : assistant_end + 1], [1] * (assistant_end - assistant_start + 1), ) + self.assertTrue( + (output_pt["assistant_masks"][i, assistant_start : assistant_end + 1] == 1).all(), + ) + # assert 1 second assistant message self.assertEqual( output["assistant_masks"][i][assistant_start2 : assistant_end2 + 1], [1] * (assistant_end2 - assistant_start2 + 1), ) + self.assertTrue( + (output_pt["assistant_masks"][i, assistant_start2 : assistant_end2 + 1] == 1).all(), + ) # assert 0 in user/system indices self.assertEqual(output["assistant_masks"][i][:assistant_start], [0] * assistant_start) + self.assertTrue((output_pt["assistant_masks"][i, :assistant_start] == 0).all()) + self.assertEqual( output["assistant_masks"][i][assistant_end + 1 : assistant_start2], [0] * (assistant_start2 - assistant_end - 1), ) + self.assertTrue( + (output_pt["assistant_masks"][i, assistant_end + 1 : assistant_start2] == 0).all(), + ) # check not batched output = tokenizer_r.apply_chat_template( @@ -1318,6 +1349,17 @@ class TokenizerTesterMixin: return_assistant_tokens_mask=True, return_dict=True, ) + output_pt = tokenizer_r.apply_chat_template( + conversations[0], + chat_template=dummy_template, + tokenize=True, + return_assistant_tokens_mask=True, + return_dict=True, + return_tensors="pt", + ) + + self.assertEqual(type(output_pt["assistant_masks"]), torch.Tensor) + self.assertEqual(output_pt["assistant_masks"].shape, output_pt["input_ids"].shape) chat_string = tokenizer_r.apply_chat_template( conversations[0], tokenize=False, chat_template=dummy_template @@ -1336,17 +1378,27 @@ class TokenizerTesterMixin: output["assistant_masks"][assistant_start : assistant_end + 1], [1] * (assistant_end - assistant_start + 1), ) + self.assertTrue( + (output_pt["assistant_masks"][assistant_start : assistant_end + 1] == 1).all(), + ) self.assertEqual( output["assistant_masks"][assistant_start2 : assistant_end2 + 1], [1] * (assistant_end2 - assistant_start2 + 1), ) + self.assertTrue( + (output_pt["assistant_masks"][assistant_start2 : assistant_end2 + 1] == 1).all(), + ) # assert 0 in user/system indices self.assertEqual(output["assistant_masks"][:assistant_start], [0] * assistant_start) + self.assertTrue((output_pt["assistant_masks"][0, :assistant_start] == 0).all()) self.assertEqual( output["assistant_masks"][assistant_end + 1 : assistant_start2], [0] * (assistant_start2 - assistant_end - 1), ) + self.assertTrue( + (output_pt["assistant_masks"][0, assistant_end + 1 : assistant_start2] == 0).all(), + ) @require_jinja def test_chat_template_return_assistant_tokens_mask_truncated(self):