support chat generator as input of TextGenerationPipeline (#35551)

* support chat generator as input of TextGenerationPipeline

* missing import

* fix tests

* again

* simpler

* add test
This commit is contained in:
Quentin Lhoest 2025-01-08 13:27:07 +01:00 committed by GitHub
parent ebdd1ad400
commit 3fde88b19d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 66 additions and 9 deletions

View File

@ -1,4 +1,6 @@
import enum
import itertools
import types
import warnings
from typing import Dict
@ -260,16 +262,27 @@ class TextGenerationPipeline(Pipeline):
ids of the generated text.
"""
if isinstance(
text_inputs, (list, tuple, KeyDataset) if is_torch_available() else (list, tuple)
) and isinstance(text_inputs[0], (list, tuple, dict)):
# We have one or more prompts in list-of-dicts format, so this is chat mode
if isinstance(text_inputs[0], dict):
return super().__call__(Chat(text_inputs), **kwargs)
text_inputs,
(list, tuple, types.GeneratorType, KeyDataset)
if is_torch_available()
else (list, tuple, types.GeneratorType),
):
if isinstance(text_inputs, types.GeneratorType):
text_inputs, _ = itertools.tee(text_inputs)
text_inputs, first_item = (x for x in text_inputs), next(_)
else:
chats = [Chat(chat) for chat in text_inputs] # 🐈 🐈 🐈
return super().__call__(chats, **kwargs)
else:
return super().__call__(text_inputs, **kwargs)
first_item = text_inputs[0]
if isinstance(first_item, (list, tuple, dict)):
# We have one or more prompts in list-of-dicts format, so this is chat mode
if isinstance(first_item, dict):
return super().__call__(Chat(text_inputs), **kwargs)
else:
chats = (Chat(chat) for chat in text_inputs) # 🐈 🐈 🐈
if isinstance(text_inputs, types.GeneratorType):
return super().__call__(chats, **kwargs)
else:
return super().__call__(list(chats), **kwargs)
return super().__call__(text_inputs, **kwargs)
def preprocess(
self,

View File

@ -292,6 +292,50 @@ class TextGenerationPipelineTests(unittest.TestCase):
],
)
@require_torch
def test_small_chat_model_with_iterator_pt(self):
from transformers.pipelines.pt_utils import PipelineIterator
text_generator = pipeline(
task="text-generation", model="hf-internal-testing/tiny-gpt2-with-chatml-template", framework="pt"
)
# Using `do_sample=False` to force deterministic output
chat1 = [
{"role": "system", "content": "This is a system message."},
{"role": "user", "content": "This is a test"},
]
chat2 = [
{"role": "system", "content": "This is a system message."},
{"role": "user", "content": "This is a second test"},
]
expected_chat1 = chat1 + [
{
"role": "assistant",
"content": " factors factors factors factors factors factors factors factors factors factors",
}
]
expected_chat2 = chat2 + [
{
"role": "assistant",
"content": " stairs stairs stairs stairs stairs stairs stairs stairs stairs stairs",
}
]
def data():
yield from [chat1, chat2]
outputs = text_generator(data(), do_sample=False, max_new_tokens=10)
assert isinstance(outputs, PipelineIterator)
outputs = list(outputs)
self.assertEqual(
outputs,
[
[{"generated_text": expected_chat1}],
[{"generated_text": expected_chat2}],
],
)
@require_tf
def test_small_model_tf(self):
text_generator = pipeline(task="text-generation", model="sshleifer/tiny-ctrl", framework="tf")