diff --git a/src/transformers/models/llama/tokenization_llama.py b/src/transformers/models/llama/tokenization_llama.py
index ddc643a360f..00dbc117a4c 100644
--- a/src/transformers/models/llama/tokenization_llama.py
+++ b/src/transformers/models/llama/tokenization_llama.py
@@ -356,6 +356,17 @@ class LlamaTokenizer(PreTrainedTokenizer):
             `List[int]`: Input ids for the conversation.
         """
+        if len(conversation.past_user_inputs) > 0:
+            if not conversation.past_user_inputs[0].startswith(B_SYS) or E_SYS not in conversation.past_user_inputs[0]:
+                conversation.past_user_inputs[0] = (
+                    B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.past_user_inputs[0]
+                )
+        elif conversation.new_user_input:
+            if not conversation.new_user_input.startswith(B_SYS) or E_SYS not in conversation.new_user_input:
+                conversation.new_user_input = B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.new_user_input
+        else:
+            raise ValueError("Last message must be from user")
+
         dialogue = list(conversation.iter_texts())
         if not all([is_user for is_user, msg in dialogue[::2]]) or not all(
             [not is_user for is_user, msg in dialogue[1::2]]
@@ -365,14 +376,6 @@ class LlamaTokenizer(PreTrainedTokenizer):
             )
 
         dialog_tokens: List[int] = []
-        if len(conversation.past_user_inputs) > 0:
-            if not conversation.past_user_inputs[0].startswith(B_SYS) or E_SYS not in conversation.past_user_inputs[0]:
-                conversation.past_user_inputs[0] = (
-                    B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.past_user_inputs[0]
-                )
-        elif not dialogue[0][1].startswith(B_SYS) or E_SYS not in dialogue[0][1]:
-            dialogue[0] = (dialogue[0][0], B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + dialogue[0][1])
-
         dialog_tokens += sum(
             [
                 [self.bos_token_id]
@@ -384,8 +387,6 @@ class LlamaTokenizer(PreTrainedTokenizer):
             ],
             [],
         )
-        if not (dialogue[-1][0]):
-            raise ValueError(f"Last message must be from user, got {dialogue[-1]['role']}")
         dialog_tokens += [self.bos_token_id] + self.encode(
             f"{B_INST} {(dialogue[-1][1]).strip()} {E_INST}", add_special_tokens=False
         )
diff --git a/src/transformers/models/llama/tokenization_llama_fast.py b/src/transformers/models/llama/tokenization_llama_fast.py
index c04e2da114c..82dfbe8925f 100644
--- a/src/transformers/models/llama/tokenization_llama_fast.py
+++ b/src/transformers/models/llama/tokenization_llama_fast.py
@@ -212,6 +212,17 @@ class LlamaTokenizerFast(PreTrainedTokenizerFast):
             `List[int]`: Input ids for the conversation.
         """
+        if len(conversation.past_user_inputs) > 0:
+            if not conversation.past_user_inputs[0].startswith(B_SYS) or E_SYS not in conversation.past_user_inputs[0]:
+                conversation.past_user_inputs[0] = (
+                    B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.past_user_inputs[0]
+                )
+        elif conversation.new_user_input:
+            if not conversation.new_user_input.startswith(B_SYS) or E_SYS not in conversation.new_user_input:
+                conversation.new_user_input = B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.new_user_input
+        else:
+            raise ValueError("Last message must be from user")
+
         dialogue = list(conversation.iter_texts())
         if not all([is_user for is_user, msg in dialogue[::2]]) or not all(
             [not is_user for is_user, msg in dialogue[1::2]]
@@ -221,14 +232,6 @@ class LlamaTokenizerFast(PreTrainedTokenizerFast):
             )
 
         dialog_tokens = []
-        if len(conversation.past_user_inputs) > 0:
-            if not conversation.past_user_inputs[0].startswith(B_SYS) or E_SYS not in conversation.past_user_inputs[0]:
-                conversation.past_user_inputs[0] = (
-                    B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.past_user_inputs[0]
-                )
-        elif not dialogue[0][1].startswith(B_SYS) or E_SYS not in dialogue[0][1]:
-            dialogue[0] = (dialogue[0][0], B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + dialogue[0][1])
-
         dialog_tokens += sum(
             [
                 [self.bos_token_id]
@@ -240,8 +243,6 @@ class LlamaTokenizerFast(PreTrainedTokenizerFast):
             ],
             [],
         )
-        if not (dialogue[-1][0]):
-            raise ValueError(f"Last message must be from user, got {dialogue[-1]['role']}")
         dialog_tokens += [self.bos_token_id] + self.encode(
             f"{B_INST} {(dialogue[-1][1]).strip()} {E_INST}", add_special_tokens=False
         )