Mirror of https://github.com/huggingface/transformers.git
apply_chat_template: consistent behaviour for return_assistant_tokens_mask=True return_tensors=True (#35582)
* apply_chat_template: consistent return_tensors behaviour with return_assistant_tokens_mask flag
* test_chat_template_return_assistant_tokens_mask: support tokenizers with no attention mask
* test_chat_template_return_assistant_tokens_mask: skip tokenizers with no padding token
* test_chat_template_return_assistant_tokens_mask: force tokenizer padding_side=right

---------

Co-authored-by: Eduard Allakhverdov <goncharova@airi.net>
Co-authored-by: d.tarasov <d.tarasov@airi.net>
parent 9c02cb6233
commit 2ba040a71f
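For callers, the fix means `assistant_masks` now follows the requested tensor type like every other field of the returned `BatchEncoding`. A minimal sketch of the call this commit makes consistent; the `gpt2` checkpoint and the inline template are illustrative only, any fast tokenizer plus a chat template containing `{% generation %}` blocks would do:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint

# return_assistant_tokens_mask needs a template that wraps assistant
# turns in {% generation %} ... {% endgeneration %} blocks.
chat_template = (
    "{% for message in messages %}"
    "{% if message['role'] == 'assistant' %}"
    "{% generation %}{{ message['content'] }}{% endgeneration %}"
    "{% else %}{{ message['content'] }}{% endif %}"
    "{% endfor %}"
)

conversation = [
    {"role": "user", "content": "Hi!"},
    {"role": "assistant", "content": "Hello, how can I help?"},
]

out = tokenizer.apply_chat_template(
    conversation,
    chat_template=chat_template,
    tokenize=True,
    return_dict=True,
    return_assistant_tokens_mask=True,
    return_tensors="pt",
)

# After this commit the mask is converted along with the other outputs:
print(type(out["assistant_masks"]))                            # <class 'torch.Tensor'>
print(out["assistant_masks"].shape == out["input_ids"].shape)  # True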
src/transformers/tokenization_utils_base.py

@@ -1742,7 +1742,15 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
                             for token_id in range(start_token, end_token + 1 if end_token else len(input_ids[i])):
                                 current_mask[token_id] = 1
                         assistant_masks.append(current_mask)
-                    out["assistant_masks"] = assistant_masks if is_batched else assistant_masks[0]
+
+                    if not is_batched and not return_tensors:
+                        assistant_masks = assistant_masks[0]
+
+                    out["assistant_masks"] = assistant_masks
+
+                if return_tensors:
+                    out.convert_to_tensors(tensor_type=return_tensors)
+
                 return out
             else:
                 return out["input_ids"]
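The ordering in this hunk is the heart of the fix: `out["assistant_masks"]` is assigned before `convert_to_tensors` runs, so the mask is converted together with `input_ids` and `attention_mask`, and the non-batched list is only unwrapped when no tensor type was requested. A standalone sketch of that mechanism with made-up values (only `BatchEncoding.convert_to_tensors` is real API here):

from transformers import BatchEncoding

# Illustrative padded batch plus its assistant mask, stored before conversion.
out = BatchEncoding(
    {
        "input_ids": [[1, 2, 3, 4], [5, 6, 0, 0]],
        "attention_mask": [[1, 1, 1, 1], [1, 1, 0, 0]],
        "assistant_masks": [[0, 0, 1, 1], [0, 1, 0, 0]],
    }
)

# convert_to_tensors walks every key of the BatchEncoding, which is why
# assigning assistant_masks first yields a tensor mask for free.
out.convert_to_tensors(tensor_type="pt")
print(type(out["assistant_masks"]))  # <class 'torch.Tensor'>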
tests/test_tokenization_common.py

@@ -62,6 +62,7 @@ from transformers.tokenization_utils import AddedToken


 if is_torch_available():
+    import torch
     import torch.nn as nn


@@ -1219,6 +1220,7 @@ class TokenizerTesterMixin:
         self.assertEqual(len(strftime_output), 10)
         self.assertEqual(len(strftime_output.split("-")), 3)

+    @require_torch
     @require_jinja
     def test_chat_template_return_assistant_tokens_mask(self):
         dummy_template = (
@@ -1263,6 +1265,9 @@ class TokenizerTesterMixin:
             self.skipTest(reason="No fast tokenizer defined")

         tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name)
+        self._check_no_pad_token_padding(tokenizer_r, conversations)
+
+        tokenizer_r.padding_side = "right"

         # check batched
         output = tokenizer_r.apply_chat_template(
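Forcing `padding_side = "right"` is what keeps the character-derived spans valid: pad tokens are appended after the conversation, so assistant token indices computed from the unpadded string do not move, and the padded tail is all zeros in the mask. A toy illustration with made-up token ids:

# With right padding the assistant span keeps its indices; with left
# padding every index would shift by the pad length.
ids = [10, 11, 12]      # assistant tokens at indices 1..2
right = ids + [0, 0]    # still at indices 1..2
left = [0, 0] + ids     # shifted to indices 3..4
assert right[1:3] == left[3:5] == [11, 12]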
@@ -1272,6 +1277,20 @@ class TokenizerTesterMixin:
             return_assistant_tokens_mask=True,
             return_dict=True,
         )
+
+        output_pt = tokenizer_r.apply_chat_template(
+            conversations,
+            chat_template=dummy_template,
+            tokenize=True,
+            padding=True,
+            return_assistant_tokens_mask=True,
+            return_dict=True,
+            return_tensors="pt",
+        )
+
+        self.assertEqual(type(output_pt["assistant_masks"]), torch.Tensor)
+        self.assertEqual(output_pt["assistant_masks"].shape, output_pt["input_ids"].shape)
+
         for i, conv in enumerate(conversations):
             chat_string = tokenizer_r.apply_chat_template(
                 conversations[i], tokenize=False, chat_template=dummy_template
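The `chat_string` rendered here feeds the elided lines between these hunks, where the test derives `assistant_start`/`assistant_end` by locating substrings in the rendered conversation and mapping character offsets to token indices through the fast tokenizer's alignment. A small sketch of that mapping, with `gpt2` standing in for any fast tokenizer:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # placeholder fast tokenizer

chat_string = "user: Hi! assistant: Hello, how can I help?"
enc = tok(chat_string)

# char_to_token maps a character position in the string to the index of
# the token covering it, turning a span found with str.index() into a
# (start_token, end_token) pair.
start_char = chat_string.index("Hello")
start_token = enc.char_to_token(start_char)
print(start_token, enc.tokens()[start_token])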
@@ -1297,18 +1316,30 @@ class TokenizerTesterMixin:
                 output["assistant_masks"][i][assistant_start : assistant_end + 1],
                 [1] * (assistant_end - assistant_start + 1),
             )
+            self.assertTrue(
+                (output_pt["assistant_masks"][i, assistant_start : assistant_end + 1] == 1).all(),
+            )

             # assert 1 second assistant message
             self.assertEqual(
                 output["assistant_masks"][i][assistant_start2 : assistant_end2 + 1],
                 [1] * (assistant_end2 - assistant_start2 + 1),
             )
+            self.assertTrue(
+                (output_pt["assistant_masks"][i, assistant_start2 : assistant_end2 + 1] == 1).all(),
+            )

             # assert 0 in user/system indices
             self.assertEqual(output["assistant_masks"][i][:assistant_start], [0] * assistant_start)
+            self.assertTrue((output_pt["assistant_masks"][i, :assistant_start] == 0).all())
+
             self.assertEqual(
                 output["assistant_masks"][i][assistant_end + 1 : assistant_start2],
                 [0] * (assistant_start2 - assistant_end - 1),
             )
+            self.assertTrue(
+                (output_pt["assistant_masks"][i, assistant_end + 1 : assistant_start2] == 0).all(),
+            )
+
         # check not batched
         output = tokenizer_r.apply_chat_template(
@@ -1318,6 +1349,17 @@ class TokenizerTesterMixin:
             return_assistant_tokens_mask=True,
             return_dict=True,
         )
+        output_pt = tokenizer_r.apply_chat_template(
+            conversations[0],
+            chat_template=dummy_template,
+            tokenize=True,
+            return_assistant_tokens_mask=True,
+            return_dict=True,
+            return_tensors="pt",
+        )
+
+        self.assertEqual(type(output_pt["assistant_masks"]), torch.Tensor)
+        self.assertEqual(output_pt["assistant_masks"].shape, output_pt["input_ids"].shape)

         chat_string = tokenizer_r.apply_chat_template(
             conversations[0], tokenize=False, chat_template=dummy_template
@@ -1336,17 +1378,27 @@ class TokenizerTesterMixin:
             output["assistant_masks"][assistant_start : assistant_end + 1],
             [1] * (assistant_end - assistant_start + 1),
         )
+        self.assertTrue(
+            (output_pt["assistant_masks"][assistant_start : assistant_end + 1] == 1).all(),
+        )
         self.assertEqual(
             output["assistant_masks"][assistant_start2 : assistant_end2 + 1],
             [1] * (assistant_end2 - assistant_start2 + 1),
         )
+        self.assertTrue(
+            (output_pt["assistant_masks"][assistant_start2 : assistant_end2 + 1] == 1).all(),
+        )

         # assert 0 in user/system indices
         self.assertEqual(output["assistant_masks"][:assistant_start], [0] * assistant_start)
+        self.assertTrue((output_pt["assistant_masks"][0, :assistant_start] == 0).all())
         self.assertEqual(
             output["assistant_masks"][assistant_end + 1 : assistant_start2],
             [0] * (assistant_start2 - assistant_end - 1),
         )
+        self.assertTrue(
+            (output_pt["assistant_masks"][0, assistant_end + 1 : assistant_start2] == 0).all(),
+        )

     @require_jinja
     def test_chat_template_return_assistant_tokens_mask_truncated(self):