mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 10:12:23 +06:00
Backward compatibility fix for the Conversation class (#27176)
* Backward compatibility fix for the Conversation class * Explain what's going on in the conditional
This commit is contained in:
parent
309a90664f
commit
05f2290114
@ -54,6 +54,7 @@ class Conversation:
|
||||
|
||||
# This block deals with the legacy args - new code should just totally
|
||||
# avoid past_user_inputs and generated_responses
|
||||
self._num_processed_user_inputs = 0
|
||||
generated_responses = deprecated_kwargs.pop("generated_responses", None)
|
||||
past_user_inputs = deprecated_kwargs.pop("past_user_inputs", None)
|
||||
if generated_responses is not None and past_user_inputs is None:
|
||||
@ -114,10 +115,11 @@ class Conversation:
|
||||
|
||||
def mark_processed(self):
|
||||
"""
|
||||
This is a legacy method that no longer has any effect, as the Conversation no longer distinguishes between
|
||||
processed and unprocessed user input.
|
||||
This is a legacy method, as the Conversation no longer distinguishes between processed and unprocessed user
|
||||
input. We set a counter here to keep behaviour mostly backward-compatible, but in general you should just read
|
||||
the messages directly when writing new code.
|
||||
"""
|
||||
pass
|
||||
self._num_processed_user_inputs = len(self._user_messages)
|
||||
|
||||
def __iter__(self):
|
||||
for message in self.messages:
|
||||
@ -163,7 +165,17 @@ class Conversation:
|
||||
@property
|
||||
def past_user_inputs(self):
|
||||
# This is a legacy property for backwards compatibility. It is recommended to just directly access
|
||||
# conversation.messages instead.
|
||||
# conversation.messages instead. The modern class does not care about which messages are "processed"
|
||||
# or not.
|
||||
if not self._user_messages:
|
||||
return []
|
||||
# In the past, the most recent user message had to be mark_processed() before being included
|
||||
# in past_user_messages. The class essentially had a single-message buffer, representing messages that
|
||||
# had not yet been replied to. This is no longer the case, but we mimic the behaviour in this property
|
||||
# for backward compatibility.
|
||||
if self.messages[-1]["role"] != "user" or self._num_processed_user_inputs == len(self._user_messages):
|
||||
return self._user_messages
|
||||
|
||||
return self._user_messages[:-1]
|
||||
|
||||
@property
|
||||
|
@ -136,8 +136,8 @@ class ConversationalPipelineTests(unittest.TestCase):
|
||||
conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
|
||||
conversation_2 = Conversation("What's the last book you have read?")
|
||||
# Then
|
||||
self.assertEqual(len(conversation_1.past_user_inputs), 1)
|
||||
self.assertEqual(len(conversation_2.past_user_inputs), 1)
|
||||
self.assertEqual(len(conversation_1.past_user_inputs), 0)
|
||||
self.assertEqual(len(conversation_2.past_user_inputs), 0)
|
||||
# When
|
||||
result = conversation_agent([conversation_1, conversation_2], do_sample=False, max_length=1000)
|
||||
# Then
|
||||
@ -167,7 +167,7 @@ class ConversationalPipelineTests(unittest.TestCase):
|
||||
conversation_agent = pipeline(task="conversational", min_length_for_response=24, device=torch_device)
|
||||
conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
|
||||
# Then
|
||||
self.assertEqual(len(conversation_1.past_user_inputs), 1)
|
||||
self.assertEqual(len(conversation_1.past_user_inputs), 0)
|
||||
# When
|
||||
result = conversation_agent(conversation_1, do_sample=False, max_length=36)
|
||||
# Then
|
||||
@ -375,8 +375,8 @@ These are just a few of the many attractions that Paris has to offer. With so mu
|
||||
conversation_1 = Conversation("My name is Sarah and I live in London")
|
||||
conversation_2 = Conversation("Going to the movies tonight, What movie would you recommend? ")
|
||||
# Then
|
||||
self.assertEqual(len(conversation_1.past_user_inputs), 1)
|
||||
self.assertEqual(len(conversation_2.past_user_inputs), 1)
|
||||
self.assertEqual(len(conversation_1.past_user_inputs), 0)
|
||||
self.assertEqual(len(conversation_2.past_user_inputs), 0)
|
||||
# When
|
||||
result = conversation_agent([conversation_1, conversation_2], do_sample=False, max_length=1000)
|
||||
# Then
|
||||
|
Loading…
Reference in New Issue
Block a user