Fix edge case for continue_final_message (#36404)

* Fix edge case for continue_final_message

* lstrip() correctly

* Add regression test

* Add a clearer error message when the final message is not present

* Add a clearer error message when the final message is not present

* Fix massive bug!
Matt 2025-03-03 18:03:03 +00:00 committed by GitHub
parent 2aff938992
commit 1975be4d97
2 changed files with 51 additions and 7 deletions

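For context, the edge case from https://github.com/huggingface/transformers/issues/35433 can be reproduced with plain string operations. Below is a minimal sketch of the failure mode using a hand-written rendered chat rather than a real tokenizer: when the template trims the final message's trailing whitespace, the old rindex()-based truncation could latch onto an earlier "decoy" message.

# Minimal sketch of the bug (hand-written strings, not library code).
# The template trims messages, so the final "bye: " renders as "bye:".
rendered = (
    "<|im_start|>user\nhi 0<|im_end|>\n"
    "<|im_start|>assistant\nbye: 0<|im_end|>\n"
    "<|im_start|>user\nhi 1<|im_end|>\n"
    "<|im_start|>assistant\nbye:<|im_end|>\n"
)
final_message = "bye: "  # note the trailing space

# Old behaviour: rindex on the untrimmed message matches the "bye: " inside the
# earlier "bye: 0" turn, so the prefill is cut off after the FIRST assistant turn.
old_end = rendered.rindex(final_message) + len(final_message)
print(repr(rendered[:old_end]))  # ends with "assistant\nbye: " from the decoy turn (wrong)

# New behaviour: locate the stripped message, then check how much of the original
# spacing survived rendering before deciding where to cut.
loc = rendered.rindex(final_message.strip())
if rendered[loc : loc + len(final_message.lstrip())] == final_message:
    new_end = loc + len(final_message.lstrip())  # spacing survived rendering
else:
    new_end = loc + len(final_message.strip())   # spacing was trimmed by the template
print(repr(rendered[:new_end]))  # ends with "hi 1<|im_end|>\n<|im_start|>assistant\nbye:" (correct)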

@@ -1702,13 +1702,30 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
             if continue_final_message:
                 final_message = chat[-1]["content"]
                 if isinstance(final_message, (list, tuple)):
-                    final_message = final_message[-1]["text"]
-                try:
-                    rendered_chat = rendered_chat[: rendered_chat.rindex(final_message) + len(final_message)]
-                except:  # noqa: E722
-                    # Some chat templates like Llama-3.1 trim messages before rendering, so we must do the same here.
-                    final_message = final_message.strip()
-                    rendered_chat = rendered_chat[: rendered_chat.rindex(final_message) + len(final_message)]
+                    for content_block in reversed(final_message):
+                        if "text" in content_block:
+                            # Pick the last text block in the message (the first one we hit while iterating in reverse)
+                            final_message = content_block["text"]
+                            break
+                    else:
+                        raise ValueError(
+                            "continue_final_message is set but we could not find any text to continue"
+                            "in the final message!"
+                        )
+                if final_message.strip() not in rendered_chat:
+                    raise ValueError(
+                        "continue_final_message is set but the final message does not appear in the chat after "
+                        "applying the chat template! This can happen if the chat template deletes portions of "
+                        "the final message. Please verify the chat template and final message in your chat to "
+                        "ensure they are compatible."
+                    )
+                final_msg_loc = rendered_chat.rindex(final_message.strip())
+                if rendered_chat[final_msg_loc : final_msg_loc + len(final_message.lstrip())] == final_message:
+                    # The template preserves spacing or the message doesn't have trailing spacing, so things are simple
+                    rendered_chat = rendered_chat[: final_msg_loc + len(final_message.lstrip())]
+                else:
+                    # The message has trailing spacing that was trimmed, so we must be more cautious
+                    rendered_chat = rendered_chat[: final_msg_loc + len(final_message.strip())]
             rendered.append(rendered_chat)
 
         if not is_batched:

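The other half of the change is how a final message made of content blocks is handled: instead of blindly taking the last block's "text" (which raised a KeyError for image-only trailing blocks), the new code walks the blocks in reverse and continues from the last text block, raising a clear error if there is none. A small standalone sketch of that selection follows; the content list here is made up for illustration.

final_message = [
    {"type": "text", "text": "Look at this picture:"},
    {"type": "image", "url": "https://example.com/cat.png"},  # no "text" key
    {"type": "text", "text": "It shows "},  # <- the block that gets continued
]
for content_block in reversed(final_message):
    if "text" in content_block:
        text_to_continue = content_block["text"]
        break
else:
    # An image-only final message has nothing to continue from
    raise ValueError("continue_final_message is set but there is no text in the final message!")

print(text_to_continue)  # "It shows "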

@@ -1565,6 +1565,33 @@ class TokenizerTesterMixin:
                     "<|im_start|>system\nsystem message<|im_end|>\n<|im_start|>user\nuser message<|im_end|>\n<|im_start|>assistant\nassistant message",
                 )
 
+    @require_jinja
+    def test_continue_final_message_with_decoy_earlier_message(self):
+        """Regression test for chat templates where an earlier message has similar content to the final message
+        https://github.com/huggingface/transformers/issues/35433"""
+        dummy_template = """
+        {%- for message in messages %}
+            {{- "<|im_start|>" + message['role'] + "\n" + message['content'] | trim + "<|im_end|>" + "\n"}}
+        {%- endfor %}"""
+        dummy_conversation = [
+            {"role": "user", "content": "hi 0"},
+            {"role": "assistant", "content": "bye: 0"},
+            {"role": "user", "content": "hi 1"},
+            {"role": "assistant", "content": "bye: "},
+        ]
+        tokenizers = self.get_tokenizers()
+        for tokenizer in tokenizers:
+            with self.subTest(f"{tokenizer.__class__.__name__}"):
+                prefill_output = tokenizer.apply_chat_template(
+                    dummy_conversation, chat_template=dummy_template, tokenize=False, continue_final_message=True
+                )
+                # Assert that the final message is unterminated
+                self.assertEqual(
+                    prefill_output,
+                    "<|im_start|>user\nhi 0<|im_end|>\n<|im_start|>assistant\nbye: 0<|im_end|>\n<|im_start|>user\nhi 1<|im_end|>\n<|im_start|>assistant\nbye:",
+                )
+
     @require_jinja
     def test_chat_template_dict(self):
         dummy_template_1 = "{{'a'}}"
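
In normal use the feature looks like the brief sketch below; the checkpoint name is only an example, and any tokenizer with a chat template works the same way.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # example checkpoint
chat = [
    {"role": "user", "content": "One word: what colour is the sky?"},
    {"role": "assistant", "content": "The sky is"},  # partial reply to be continued by the model
]
prompt = tokenizer.apply_chat_template(chat, tokenize=False, continue_final_message=True)
# The rendered prompt ends with "The sky is" and no end-of-turn token,
# so generation continues the assistant's sentence instead of opening a new turn.
print(prompt)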