apply_chat_template: consistent behaviour for return_assistant_tokens_mask=True return_tensors=True (#35582)

* apply_chat_template: consistent return_tensors behaviour with return_assistant_tokens_mask flag

* test_chat_template_return_assistant_tokens_mask: support tokenizers with no attention mask

* test_chat_template_return_assistant_tokens_mask: skip tokenizers with no padding token

* test_chat_template_return_assistant_tokens_mask: force tokenizer padding_side=right

---------

Co-authored-by: Eduard Allakhverdov <goncharova@airi.net>
Co-authored-by: d.tarasov <d.tarasov@airi.net>
This commit is contained in:
Dmitry Tarasov 2025-02-04 12:27:52 +03:00 committed by GitHub
parent 9c02cb6233
commit 2ba040a71f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 61 additions and 1 deletions

View File

@ -1742,7 +1742,15 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
for token_id in range(start_token, end_token + 1 if end_token else len(input_ids[i])):
current_mask[token_id] = 1
assistant_masks.append(current_mask)
out["assistant_masks"] = assistant_masks if is_batched else assistant_masks[0]
if not is_batched and not return_tensors:
assistant_masks = assistant_masks[0]
out["assistant_masks"] = assistant_masks
if return_tensors:
out.convert_to_tensors(tensor_type=return_tensors)
return out
else:
return out["input_ids"]

View File

@ -62,6 +62,7 @@ from transformers.tokenization_utils import AddedToken
if is_torch_available():
import torch
import torch.nn as nn
@ -1219,6 +1220,7 @@ class TokenizerTesterMixin:
self.assertEqual(len(strftime_output), 10)
self.assertEqual(len(strftime_output.split("-")), 3)
@require_torch
@require_jinja
def test_chat_template_return_assistant_tokens_mask(self):
dummy_template = (
@ -1263,6 +1265,9 @@ class TokenizerTesterMixin:
self.skipTest(reason="No fast tokenizer defined")
tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name)
self._check_no_pad_token_padding(tokenizer_r, conversations)
tokenizer_r.padding_side = "right"
# check batched
output = tokenizer_r.apply_chat_template(
@ -1272,6 +1277,20 @@ class TokenizerTesterMixin:
return_assistant_tokens_mask=True,
return_dict=True,
)
output_pt = tokenizer_r.apply_chat_template(
conversations,
chat_template=dummy_template,
tokenize=True,
padding=True,
return_assistant_tokens_mask=True,
return_dict=True,
return_tensors="pt",
)
self.assertEqual(type(output_pt["assistant_masks"]), torch.Tensor)
self.assertEqual(output_pt["assistant_masks"].shape, output_pt["input_ids"].shape)
for i, conv in enumerate(conversations):
chat_string = tokenizer_r.apply_chat_template(
conversations[i], tokenize=False, chat_template=dummy_template
@ -1297,18 +1316,30 @@ class TokenizerTesterMixin:
output["assistant_masks"][i][assistant_start : assistant_end + 1],
[1] * (assistant_end - assistant_start + 1),
)
self.assertTrue(
(output_pt["assistant_masks"][i, assistant_start : assistant_end + 1] == 1).all(),
)
# assert 1 second assistant message
self.assertEqual(
output["assistant_masks"][i][assistant_start2 : assistant_end2 + 1],
[1] * (assistant_end2 - assistant_start2 + 1),
)
self.assertTrue(
(output_pt["assistant_masks"][i, assistant_start2 : assistant_end2 + 1] == 1).all(),
)
# assert 0 in user/system indices
self.assertEqual(output["assistant_masks"][i][:assistant_start], [0] * assistant_start)
self.assertTrue((output_pt["assistant_masks"][i, :assistant_start] == 0).all())
self.assertEqual(
output["assistant_masks"][i][assistant_end + 1 : assistant_start2],
[0] * (assistant_start2 - assistant_end - 1),
)
self.assertTrue(
(output_pt["assistant_masks"][i, assistant_end + 1 : assistant_start2] == 0).all(),
)
# check not batched
output = tokenizer_r.apply_chat_template(
@ -1318,6 +1349,17 @@ class TokenizerTesterMixin:
return_assistant_tokens_mask=True,
return_dict=True,
)
output_pt = tokenizer_r.apply_chat_template(
conversations[0],
chat_template=dummy_template,
tokenize=True,
return_assistant_tokens_mask=True,
return_dict=True,
return_tensors="pt",
)
self.assertEqual(type(output_pt["assistant_masks"]), torch.Tensor)
self.assertEqual(output_pt["assistant_masks"].shape, output_pt["input_ids"].shape)
chat_string = tokenizer_r.apply_chat_template(
conversations[0], tokenize=False, chat_template=dummy_template
@ -1336,17 +1378,27 @@ class TokenizerTesterMixin:
output["assistant_masks"][assistant_start : assistant_end + 1],
[1] * (assistant_end - assistant_start + 1),
)
self.assertTrue(
(output_pt["assistant_masks"][assistant_start : assistant_end + 1] == 1).all(),
)
self.assertEqual(
output["assistant_masks"][assistant_start2 : assistant_end2 + 1],
[1] * (assistant_end2 - assistant_start2 + 1),
)
self.assertTrue(
(output_pt["assistant_masks"][assistant_start2 : assistant_end2 + 1] == 1).all(),
)
# assert 0 in user/system indices
self.assertEqual(output["assistant_masks"][:assistant_start], [0] * assistant_start)
self.assertTrue((output_pt["assistant_masks"][0, :assistant_start] == 0).all())
self.assertEqual(
output["assistant_masks"][assistant_end + 1 : assistant_start2],
[0] * (assistant_start2 - assistant_end - 1),
)
self.assertTrue(
(output_pt["assistant_masks"][0, assistant_end + 1 : assistant_start2] == 0).all(),
)
@require_jinja
def test_chat_template_return_assistant_tokens_mask_truncated(self):