Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
support chat generator as input of TextGenerationPipeline (#35551)
* support chat generator as input of TextGenerationPipeline
* missing import
* fix tests
* again
* simpler
* add test
Parent: ebdd1ad400
Commit: 3fde88b19d
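In short, this commit lets `TextGenerationPipeline.__call__` accept a generator of chats, not just a list or tuple. A minimal usage sketch of the new call pattern (the tiny fixture model and the prompts are taken from the test added in this commit; everything else is illustrative):

    from transformers import pipeline

    text_generator = pipeline(
        task="text-generation",
        model="hf-internal-testing/tiny-gpt2-with-chatml-template",  # tiny fixture model used by the new test
        framework="pt",
    )

    def chats():
        # Each yielded item is one chat: a list of {"role", "content"} dicts.
        yield [{"role": "user", "content": "This is a test"}]
        yield [{"role": "user", "content": "This is a second test"}]

    # Before this commit, a generator was not recognized as chat input;
    # now it is detected and wrapped in Chat objects lazily.
    for output in text_generator(chats(), do_sample=False, max_new_tokens=10):
        print(output[0]["generated_text"])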
src/transformers/pipelines/text_generation.py

@@ -1,4 +1,6 @@
 import enum
+import itertools
+import types
 import warnings
 from typing import Dict
 
@@ -260,16 +262,27 @@ class TextGenerationPipeline(Pipeline):
             ids of the generated text.
         """
         if isinstance(
-            text_inputs, (list, tuple, KeyDataset) if is_torch_available() else (list, tuple)
-        ) and isinstance(text_inputs[0], (list, tuple, dict)):
-            # We have one or more prompts in list-of-dicts format, so this is chat mode
-            if isinstance(text_inputs[0], dict):
-                return super().__call__(Chat(text_inputs), **kwargs)
-            else:
-                chats = [Chat(chat) for chat in text_inputs]  # 🐈 🐈 🐈
-                return super().__call__(chats, **kwargs)
-        else:
-            return super().__call__(text_inputs, **kwargs)
+            text_inputs,
+            (list, tuple, types.GeneratorType, KeyDataset)
+            if is_torch_available()
+            else (list, tuple, types.GeneratorType),
+        ):
+            if isinstance(text_inputs, types.GeneratorType):
+                text_inputs, _ = itertools.tee(text_inputs)
+                text_inputs, first_item = (x for x in text_inputs), next(_)
+            else:
+                first_item = text_inputs[0]
+            if isinstance(first_item, (list, tuple, dict)):
+                # We have one or more prompts in list-of-dicts format, so this is chat mode
+                if isinstance(first_item, dict):
+                    return super().__call__(Chat(text_inputs), **kwargs)
+                else:
+                    chats = (Chat(chat) for chat in text_inputs)  # 🐈 🐈 🐈
+                    if isinstance(text_inputs, types.GeneratorType):
+                        return super().__call__(chats, **kwargs)
+                    else:
+                        return super().__call__(list(chats), **kwargs)
+        return super().__call__(text_inputs, **kwargs)
 
     def preprocess(
         self,
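The subtle part of the hunk above is peeking at the first element of a generator without losing it: `itertools.tee` duplicates the stream, one copy is advanced with `next()` to classify the input (chat vs. plain text), and the other is re-wrapped in a generator expression so the later `isinstance(..., types.GeneratorType)` checks still hold. A self-contained sketch of that pattern (the `peek` helper is illustrative, not part of the commit):

    import itertools
    import types

    def peek(stream):
        # tee buffers items, so advancing one copy does not lose data:
        # copy_b yields the first item while copy_a still starts from it.
        copy_a, copy_b = itertools.tee(stream)
        # Re-wrap copy_a: tee returns an itertools._tee object, and the
        # pipeline's later isinstance checks expect a real GeneratorType.
        return (x for x in copy_a), next(copy_b)

    squares = (n * n for n in range(3))
    squares, first = peek(squares)
    assert first == 0
    assert isinstance(squares, types.GeneratorType)
    assert list(squares) == [0, 1, 4]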
tests/pipelines/test_pipelines_text_generation.py

@@ -292,6 +292,50 @@ class TextGenerationPipelineTests(unittest.TestCase):
             ],
         )
 
+    @require_torch
+    def test_small_chat_model_with_iterator_pt(self):
+        from transformers.pipelines.pt_utils import PipelineIterator
+
+        text_generator = pipeline(
+            task="text-generation", model="hf-internal-testing/tiny-gpt2-with-chatml-template", framework="pt"
+        )
+
+        # Using `do_sample=False` to force deterministic output
+        chat1 = [
+            {"role": "system", "content": "This is a system message."},
+            {"role": "user", "content": "This is a test"},
+        ]
+        chat2 = [
+            {"role": "system", "content": "This is a system message."},
+            {"role": "user", "content": "This is a second test"},
+        ]
+        expected_chat1 = chat1 + [
+            {
+                "role": "assistant",
+                "content": " factors factors factors factors factors factors factors factors factors factors",
+            }
+        ]
+        expected_chat2 = chat2 + [
+            {
+                "role": "assistant",
+                "content": " stairs stairs stairs stairs stairs stairs stairs stairs stairs stairs",
+            }
+        ]
+
+        def data():
+            yield from [chat1, chat2]
+
+        outputs = text_generator(data(), do_sample=False, max_new_tokens=10)
+        assert isinstance(outputs, PipelineIterator)
+        outputs = list(outputs)
+        self.assertEqual(
+            outputs,
+            [
+                [{"generated_text": expected_chat1}],
+                [{"generated_text": expected_chat2}],
+            ],
+        )
+
     @require_tf
     def test_small_model_tf(self):
         text_generator = pipeline(task="text-generation", model="sshleifer/tiny-ctrl", framework="tf")
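Note the assertion in the test: with a generator input, the pipeline returns a `PipelineIterator`, so results can be consumed lazily rather than materialized up front. Continuing with the `text_generator` and `data()` defined in the test above, extracting just the assistant reply might look like this (a usage sketch, not part of the commit):

    outputs = text_generator(data(), do_sample=False, max_new_tokens=10)
    for result in outputs:  # one result per input chat, streamed
        full_chat = result[0]["generated_text"]  # original turns plus the new assistant turn
        assert full_chat[-1]["role"] == "assistant"
        print(full_chat[-1]["content"])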