Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-03 21:00:08 +06:00)
Fix edge case for continue_final_message (#36404)

* Fix edge case for continue_final_message
* lstrip() correctly
* Add regression test
* Add a clearer error message when the final message is not present
* Fix massive bug!
This commit is contained in:
parent
2aff938992
commit
1975be4d97
@ -1702,13 +1702,30 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
if continue_final_message:
|
||||
final_message = chat[-1]["content"]
|
||||
if isinstance(final_message, (list, tuple)):
|
||||
final_message = final_message[-1]["text"]
|
||||
try:
|
||||
rendered_chat = rendered_chat[: rendered_chat.rindex(final_message) + len(final_message)]
|
||||
except: # noqa: E722
|
||||
# Some chat templates like Llama-3.1 trim messages before rendering, so we must do the same here.
|
||||
final_message = final_message.strip()
|
||||
rendered_chat = rendered_chat[: rendered_chat.rindex(final_message) + len(final_message)]
|
||||
for content_block in reversed(final_message):
|
||||
if "text" in content_block:
|
||||
# Pick the last text block in the message (the first one we hit while iterating in reverse)
|
||||
final_message = content_block["text"]
|
||||
break
|
||||
else:
|
||||
raise ValueError(
|
||||
"continue_final_message is set but we could not find any text to continue"
|
||||
"in the final message!"
|
||||
)
|
||||
if final_message.strip() not in rendered_chat:
|
||||
raise ValueError(
|
||||
"continue_final_message is set but the final message does not appear in the chat after "
|
||||
"applying the chat template! This can happen if the chat template deletes portions of "
|
||||
"the final message. Please verify the chat template and final message in your chat to "
|
||||
"ensure they are compatible."
|
||||
)
|
||||
final_msg_loc = rendered_chat.rindex(final_message.strip())
|
||||
if rendered_chat[final_msg_loc : final_msg_loc + len(final_message.lstrip())] == final_message:
|
||||
# The template preserves spacing or the message doesn't have trailing spacing, so things are simple
|
||||
rendered_chat = rendered_chat[: final_msg_loc + len(final_message.lstrip())]
|
||||
else:
|
||||
# The message has trailing spacing that was trimmed, so we must be more cautious
|
||||
rendered_chat = rendered_chat[: final_msg_loc + len(final_message.strip())]
|
||||
rendered.append(rendered_chat)
|
||||
|
||||
if not is_batched:
|
||||
|
@ -1565,6 +1565,33 @@ class TokenizerTesterMixin:
|
||||
"<|im_start|>system\nsystem message<|im_end|>\n<|im_start|>user\nuser message<|im_end|>\n<|im_start|>assistant\nassistant message",
|
||||
)
|
||||
|
||||
@require_jinja
def test_continue_final_message_with_decoy_earlier_message(self):
    """Regression test for chat templates where an earlier message has similar content to the final message
    https://github.com/huggingface/transformers/issues/35433"""
    # The template applies Jinja's `| trim` to every message, so the final
    # assistant content "bye: " is rendered without its trailing space.
    # The earlier assistant message "bye: 0" begins with the same text and
    # acts as a decoy: a naive rindex()-based truncation could latch onto
    # the wrong occurrence when locating the final message in the render.
    dummy_template = """
{%- for message in messages %}
{{- "<|im_start|>" + message['role'] + "\n" + message['content'] | trim + "<|im_end|>" + "\n"}}
{%- endfor %}"""
    dummy_conversation = [
        {"role": "user", "content": "hi 0"},
        {"role": "assistant", "content": "bye: 0"},
        {"role": "user", "content": "hi 1"},
        {"role": "assistant", "content": "bye: "},  # trailing space is trimmed by the template
    ]
    tokenizers = self.get_tokenizers()
    for tokenizer in tokenizers:
        with self.subTest(f"{tokenizer.__class__.__name__}"):
            # continue_final_message=True should truncate the render right
            # after the final message's text, leaving it unterminated.
            prefill_output = tokenizer.apply_chat_template(
                dummy_conversation, chat_template=dummy_template, tokenize=False, continue_final_message=True
            )
            # Assert that the final message is unterminated
            self.assertEqual(
                prefill_output,
                "<|im_start|>user\nhi 0<|im_end|>\n<|im_start|>assistant\nbye: 0<|im_end|>\n<|im_start|>user\nhi 1<|im_end|>\n<|im_start|>assistant\nbye:",
            )
|
||||
|
||||
@require_jinja
|
||||
def test_chat_template_dict(self):
|
||||
dummy_template_1 = "{{'a'}}"
|
||||
|
Loading…
Reference in New Issue
Block a user