mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Better handling missing SYS in llama conversation tokenizer (#24997)
* Better handling of missing SYS in llama conversation tokenizer. The existing code failed to add SYS if the conversation had history without SYS, yet it still modified the passed conversation object while checking. Rearrange the code so that modifications to the conversation object are taken into account when generating token ids. * Fix formatting with black * Avoid one-liners * Also fix fast tokenizer * Drop List decl
This commit is contained in:
parent
6704923107
commit
efb2ba666d
@ -356,6 +356,17 @@ class LlamaTokenizer(PreTrainedTokenizer):
|
||||
`List[int]`:
|
||||
Input ids for the conversation.
|
||||
"""
|
||||
if len(conversation.past_user_inputs) > 0:
|
||||
if not conversation.past_user_inputs[0].startswith(B_SYS) or E_SYS not in conversation.past_user_inputs[0]:
|
||||
conversation.past_user_inputs[0] = (
|
||||
B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.past_user_inputs[0]
|
||||
)
|
||||
elif conversation.new_user_input:
|
||||
if not conversation.new_user_input.startswith(B_SYS) or E_SYS not in conversation.new_user_input:
|
||||
conversation.new_user_input = B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.new_user_input
|
||||
else:
|
||||
raise ValueError("Last message must be from user")
|
||||
|
||||
dialogue = list(conversation.iter_texts())
|
||||
if not all([is_user for is_user, msg in dialogue[::2]]) or not all(
|
||||
[not is_user for is_user, msg in dialogue[1::2]]
|
||||
@ -365,14 +376,6 @@ class LlamaTokenizer(PreTrainedTokenizer):
|
||||
)
|
||||
|
||||
dialog_tokens: List[int] = []
|
||||
if len(conversation.past_user_inputs) > 0:
|
||||
if not conversation.past_user_inputs[0].startswith(B_SYS) or E_SYS not in conversation.past_user_inputs[0]:
|
||||
conversation.past_user_inputs[0] = (
|
||||
B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.past_user_inputs[0]
|
||||
)
|
||||
elif not dialogue[0][1].startswith(B_SYS) or E_SYS not in dialogue[0][1]:
|
||||
dialogue[0] = (dialogue[0][0], B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + dialogue[0][1])
|
||||
|
||||
dialog_tokens += sum(
|
||||
[
|
||||
[self.bos_token_id]
|
||||
@ -384,8 +387,6 @@ class LlamaTokenizer(PreTrainedTokenizer):
|
||||
],
|
||||
[],
|
||||
)
|
||||
if not (dialogue[-1][0]):
|
||||
raise ValueError(f"Last message must be from user, got {dialogue[-1]['role']}")
|
||||
dialog_tokens += [self.bos_token_id] + self.encode(
|
||||
f"{B_INST} {(dialogue[-1][1]).strip()} {E_INST}", add_special_tokens=False
|
||||
)
|
||||
|
@ -212,6 +212,17 @@ class LlamaTokenizerFast(PreTrainedTokenizerFast):
|
||||
`List[int]`:
|
||||
Input ids for the conversation.
|
||||
"""
|
||||
if len(conversation.past_user_inputs) > 0:
|
||||
if not conversation.past_user_inputs[0].startswith(B_SYS) or E_SYS not in conversation.past_user_inputs[0]:
|
||||
conversation.past_user_inputs[0] = (
|
||||
B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.past_user_inputs[0]
|
||||
)
|
||||
elif conversation.new_user_input:
|
||||
if not conversation.new_user_input.startswith(B_SYS) or E_SYS not in conversation.new_user_input:
|
||||
conversation.new_user_input = B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.new_user_input
|
||||
else:
|
||||
raise ValueError("Last message must be from user")
|
||||
|
||||
dialogue = list(conversation.iter_texts())
|
||||
if not all([is_user for is_user, msg in dialogue[::2]]) or not all(
|
||||
[not is_user for is_user, msg in dialogue[1::2]]
|
||||
@ -221,14 +232,6 @@ class LlamaTokenizerFast(PreTrainedTokenizerFast):
|
||||
)
|
||||
|
||||
dialog_tokens = []
|
||||
if len(conversation.past_user_inputs) > 0:
|
||||
if not conversation.past_user_inputs[0].startswith(B_SYS) or E_SYS not in conversation.past_user_inputs[0]:
|
||||
conversation.past_user_inputs[0] = (
|
||||
B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.past_user_inputs[0]
|
||||
)
|
||||
elif not dialogue[0][1].startswith(B_SYS) or E_SYS not in dialogue[0][1]:
|
||||
dialogue[0] = (dialogue[0][0], B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + dialogue[0][1])
|
||||
|
||||
dialog_tokens += sum(
|
||||
[
|
||||
[self.bos_token_id]
|
||||
@ -240,8 +243,6 @@ class LlamaTokenizerFast(PreTrainedTokenizerFast):
|
||||
],
|
||||
[],
|
||||
)
|
||||
if not (dialogue[-1][0]):
|
||||
raise ValueError(f"Last message must be from user, got {dialogue[-1]['role']}")
|
||||
dialog_tokens += [self.bos_token_id] + self.encode(
|
||||
f"{B_INST} {(dialogue[-1][1]).strip()} {E_INST}", add_special_tokens=False
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user