Backward compatibility fix for the Conversation class (#27176)

* Backward compatibility fix for the Conversation class

* Explain what's going on in the conditional
Matt 2023-10-31 15:12:06 +00:00 committed by GitHub
parent 309a90664f
commit 05f2290114
2 changed files with 21 additions and 9 deletions
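
Before the diff itself, a quick sketch of the behaviour this fix restores. The rewritten `Conversation` class had started reporting every user message in `past_user_inputs` immediately; legacy code expected the newest, not-yet-answered input to be held back until a reply arrives. (`add_message` stands in here for the pipeline appending a response; it belongs to the surrounding class rather than this diff, so treat its use as an assumption.)

```python
from transformers import Conversation

conversation = Conversation("Going to the movies tonight - any suggestions?")

# With this fix, the opening message is still "pending" (no reply yet), so the
# legacy property excludes it rather than reporting a length of 1.
assert len(conversation.past_user_inputs) == 0

# Once a reply is appended (normally by the conversational pipeline), the last
# message is no longer a user message and the pending input becomes a "past" one.
conversation.add_message({"role": "assistant", "content": "How about The Big Lebowski?"})
assert len(conversation.past_user_inputs) == 1
```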


@@ -54,6 +54,7 @@ class Conversation:
         # This block deals with the legacy args - new code should just totally
         # avoid past_user_inputs and generated_responses
+        self._num_processed_user_inputs = 0
         generated_responses = deprecated_kwargs.pop("generated_responses", None)
         past_user_inputs = deprecated_kwargs.pop("past_user_inputs", None)
         if generated_responses is not None and past_user_inputs is None:
@@ -114,10 +115,11 @@ class Conversation:
     def mark_processed(self):
         """
-        This is a legacy method that no longer has any effect, as the Conversation no longer distinguishes between
-        processed and unprocessed user input.
+        This is a legacy method, as the Conversation no longer distinguishes between processed and unprocessed user
+        input. We set a counter here to keep behaviour mostly backward-compatible, but in general you should just read
+        the messages directly when writing new code.
         """
-        pass
+        self._num_processed_user_inputs = len(self._user_messages)

     def __iter__(self):
         for message in self.messages:
@@ -163,7 +165,17 @@ class Conversation:
     @property
     def past_user_inputs(self):
         # This is a legacy property for backwards compatibility. It is recommended to just directly access
-        # conversation.messages instead.
+        # conversation.messages instead. The modern class does not care about which messages are "processed"
+        # or not.
+        if not self._user_messages:
+            return []
+        # In the past, the most recent user message had to be mark_processed() before being included
+        # in past_user_messages. The class essentially had a single-message buffer, representing messages that
+        # had not yet been replied to. This is no longer the case, but we mimic the behaviour in this property
+        # for backward compatibility.
+        if self.messages[-1]["role"] != "user" or self._num_processed_user_inputs == len(self._user_messages):
+            return self._user_messages
+        return self._user_messages[:-1]

     @property

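To make the conditional above concrete, a minimal sketch of how the new counter and the legacy property interact; everything used here (`Conversation(text)`, `mark_processed`, `past_user_inputs`) appears in the hunks above, and the empty-vs-one-element results follow from the `[:-1]` branch:

```python
from transformers import Conversation

conversation = Conversation("What's the last book you have read?")

# One user message, counter at 0: the last message has role "user" and the counter
# trails the number of user messages, so the property takes the [:-1] branch and hides it.
assert len(conversation.past_user_inputs) == 0

# mark_processed() no longer affects generation, but it advances the counter to the
# number of user messages; the equality in the conditional now holds, so the pending
# message is reported even though it has not been replied to.
conversation.mark_processed()
assert len(conversation.past_user_inputs) == 1
```

This mirrors the old single-message buffer: a user input only counts as "past" once it has either been answered or explicitly marked as processed.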

@@ -136,8 +136,8 @@ class ConversationalPipelineTests(unittest.TestCase):
         conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
         conversation_2 = Conversation("What's the last book you have read?")
         # Then
-        self.assertEqual(len(conversation_1.past_user_inputs), 1)
-        self.assertEqual(len(conversation_2.past_user_inputs), 1)
+        self.assertEqual(len(conversation_1.past_user_inputs), 0)
+        self.assertEqual(len(conversation_2.past_user_inputs), 0)
         # When
         result = conversation_agent([conversation_1, conversation_2], do_sample=False, max_length=1000)
         # Then
@@ -167,7 +167,7 @@ class ConversationalPipelineTests(unittest.TestCase):
         conversation_agent = pipeline(task="conversational", min_length_for_response=24, device=torch_device)
         conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
         # Then
-        self.assertEqual(len(conversation_1.past_user_inputs), 1)
+        self.assertEqual(len(conversation_1.past_user_inputs), 0)
         # When
         result = conversation_agent(conversation_1, do_sample=False, max_length=36)
         # Then
@@ -375,8 +375,8 @@ These are just a few of the many attractions that Paris has to offer. With so mu
         conversation_1 = Conversation("My name is Sarah and I live in London")
         conversation_2 = Conversation("Going to the movies tonight, What movie would you recommend? ")
         # Then
-        self.assertEqual(len(conversation_1.past_user_inputs), 1)
-        self.assertEqual(len(conversation_2.past_user_inputs), 1)
+        self.assertEqual(len(conversation_1.past_user_inputs), 0)
+        self.assertEqual(len(conversation_2.past_user_inputs), 0)
         # When
         result = conversation_agent([conversation_1, conversation_2], do_sample=False, max_length=1000)
         # Then
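
The flipped assertions (1 → 0) are the restored buffering seen from the tests: a freshly constructed conversation's only user input has not been answered yet, so it is not yet a *past* input. The same holds when a conversation is built through the deprecated kwargs handled just below the new counter in `__init__`; the way those kwargs are interleaved into `messages` is not shown in this diff, so the exact expectation below is an assumption rather than something the hunks guarantee:

```python
from transformers import Conversation

# Deprecated-style construction: answered history via kwargs plus a new, unanswered question.
conversation = Conversation(
    "Is it any good?",
    past_user_inputs=["Going to the movies tonight - any suggestions?"],
    generated_responses=["How about The Big Lebowski?"],
)

# Only the already-answered input counts; the trailing question is still buffered.
assert len(conversation.past_user_inputs) == 1
```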