diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py
index 121d3eaa83a..83d5c649b5c 100644
--- a/src/transformers/tokenization_utils_base.py
+++ b/src/transformers/tokenization_utils_base.py
@@ -1702,13 +1702,30 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
             if continue_final_message:
                 final_message = chat[-1]["content"]
                 if isinstance(final_message, (list, tuple)):
-                    final_message = final_message[-1]["text"]
-                try:
-                    rendered_chat = rendered_chat[: rendered_chat.rindex(final_message) + len(final_message)]
-                except:  # noqa: E722
-                    # Some chat templates like Llama-3.1 trim messages before rendering, so we must do the same here.
-                    final_message = final_message.strip()
-                    rendered_chat = rendered_chat[: rendered_chat.rindex(final_message) + len(final_message)]
+                    for content_block in reversed(final_message):
+                        if "text" in content_block:
+                            # Pick the last text block in the message (the first one we hit while iterating in reverse)
+                            final_message = content_block["text"]
+                            break
+                    else:
+                        raise ValueError(
+                            "continue_final_message is set but we could not find any text to continue "
+                            "in the final message!"
+                        )
+                if final_message.strip() not in rendered_chat:
+                    raise ValueError(
+                        "continue_final_message is set but the final message does not appear in the chat after "
+                        "applying the chat template! This can happen if the chat template deletes portions of "
+                        "the final message. Please verify the chat template and final message in your chat to "
+                        "ensure they are compatible."
+                    )
+                final_msg_loc = rendered_chat.rindex(final_message.strip())
+                if rendered_chat[final_msg_loc : final_msg_loc + len(final_message.lstrip())] == final_message:
+                    # The template preserves spacing or the message doesn't have trailing spacing, so things are simple
+                    rendered_chat = rendered_chat[: final_msg_loc + len(final_message.lstrip())]
+                else:
+                    # The message has trailing spacing that was trimmed, so we must be more cautious
+                    rendered_chat = rendered_chat[: final_msg_loc + len(final_message.strip())]
             rendered.append(rendered_chat)
 
         if not is_batched:
diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py
index f2849d5ed70..f7263dcb818 100644
--- a/tests/test_tokenization_common.py
+++ b/tests/test_tokenization_common.py
@@ -1565,6 +1565,33 @@ class TokenizerTesterMixin:
             "<|im_start|>system\nsystem message<|im_end|>\n<|im_start|>user\nuser message<|im_end|>\n<|im_start|>assistant\nassistant message",
         )
 
+    @require_jinja
+    def test_continue_final_message_with_decoy_earlier_message(self):
+        """Regression test for chat templates where an earlier message has similar content to the final message
+        https://github.com/huggingface/transformers/issues/35433"""
+
+        dummy_template = """
+        {%- for message in messages %}
+            {{- "<|im_start|>" + message['role'] + "\n" + message['content'] | trim + "<|im_end|>" + "\n"}}
+        {%- endfor %}"""
+        dummy_conversation = [
+            {"role": "user", "content": "hi 0"},
+            {"role": "assistant", "content": "bye: 0"},
+            {"role": "user", "content": "hi 1"},
+            {"role": "assistant", "content": "bye: "},
+        ]
+        tokenizers = self.get_tokenizers()
+        for tokenizer in tokenizers:
+            with self.subTest(f"{tokenizer.__class__.__name__}"):
+                prefill_output = tokenizer.apply_chat_template(
+                    dummy_conversation, chat_template=dummy_template, tokenize=False, continue_final_message=True
+                )
+                # Assert that the final message is unterminated
+                self.assertEqual(
+                    prefill_output,
+                    "<|im_start|>user\nhi 0<|im_end|>\n<|im_start|>assistant\nbye: 0<|im_end|>\n<|im_start|>user\nhi 1<|im_end|>\n<|im_start|>assistant\nbye:",
+                )
+
     @require_jinja
     def test_chat_template_dict(self):
         dummy_template_1 = "{{'a'}}"
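
Note: for reviewers, a minimal end-to-end sketch of the prefill workflow this patch hardens. The checkpoint name is an illustrative assumption (any tokenizer whose chat template trims message content, as the dummy template above does with `| trim`, exercises the same path); only `apply_chat_template` with `continue_final_message=True` comes from the patch itself.

# Sketch, not part of the patch: demonstrates why the rendered prompt must end
# exactly at the final message's (possibly trimmed) text with no end-of-turn
# token, so that generation continues the assistant's partial message.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # illustrative checkpoint
chat = [
    {"role": "user", "content": "Rate this movie from 1 to 10."},
    {"role": "assistant", "content": "Rating: "},  # trailing space may be trimmed by the template
]
prompt = tokenizer.apply_chat_template(chat, tokenize=False, continue_final_message=True)
# The prompt ends with the final message's text, not an <|im_end|>-style token,
# even when an earlier message ("Rating: 7" in a longer chat, say) is a decoy.
assert prompt.rstrip().endswith("Rating:")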