Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-01 18:51:14 +06:00)
support chat generator as input of TextGenerationPipeline (#35551)
* support chat generator as input of TextGenerationPipeline
* missing import
* fix tests
* again
* simpler
* add test
parent ebdd1ad400
commit 3fde88b19d
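
In practice, the change lets callers stream conversations through the pipeline instead of building a list up front. A minimal sketch of the new usage (assuming a PyTorch environment; the tiny chat model is the one exercised by the added test, and the chats helper name is illustrative):

from transformers import pipeline

pipe = pipeline(
    task="text-generation",
    model="hf-internal-testing/tiny-gpt2-with-chatml-template",
    framework="pt",
)

# Conversations can now arrive one at a time from a generator,
# so a large dataset of chats never has to be materialized in memory.
def chats():
    yield [{"role": "user", "content": "This is a test"}]
    yield [{"role": "user", "content": "This is a second test"}]

# Results stream back lazily as a PipelineIterator rather than a list.
for out in pipe(chats(), do_sample=False, max_new_tokens=10):
    print(out[0]["generated_text"])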
@@ -1,4 +1,6 @@
 import enum
+import itertools
+import types
 import warnings
 from typing import Dict
 
@@ -260,16 +262,27 @@ class TextGenerationPipeline(Pipeline):
             ids of the generated text.
         """
         if isinstance(
-            text_inputs, (list, tuple, KeyDataset) if is_torch_available() else (list, tuple)
-        ) and isinstance(text_inputs[0], (list, tuple, dict)):
-            # We have one or more prompts in list-of-dicts format, so this is chat mode
-            if isinstance(text_inputs[0], dict):
-                return super().__call__(Chat(text_inputs), **kwargs)
+            text_inputs,
+            (list, tuple, types.GeneratorType, KeyDataset)
+            if is_torch_available()
+            else (list, tuple, types.GeneratorType),
+        ):
+            if isinstance(text_inputs, types.GeneratorType):
+                text_inputs, _ = itertools.tee(text_inputs)
+                text_inputs, first_item = (x for x in text_inputs), next(_)
             else:
-                chats = [Chat(chat) for chat in text_inputs]  # 🐈 🐈 🐈
-                return super().__call__(chats, **kwargs)
-        else:
-            return super().__call__(text_inputs, **kwargs)
+                first_item = text_inputs[0]
+            if isinstance(first_item, (list, tuple, dict)):
+                # We have one or more prompts in list-of-dicts format, so this is chat mode
+                if isinstance(first_item, dict):
+                    return super().__call__(Chat(text_inputs), **kwargs)
+                else:
+                    chats = (Chat(chat) for chat in text_inputs)  # 🐈 🐈 🐈
+                    if isinstance(text_inputs, types.GeneratorType):
+                        return super().__call__(chats, **kwargs)
+                    else:
+                        return super().__call__(list(chats), **kwargs)
+        return super().__call__(text_inputs, **kwargs)
 
     def preprocess(
         self,
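
The subtlest part of the new __call__ logic above is peeking at a generator's first item, to decide between chat and plain-text mode, without consuming that item. The same itertools.tee idiom in isolation (a standalone sketch; the peek helper name is illustrative):

import itertools
import types

def peek(gen):
    # tee() yields two independent iterators over the same stream;
    # advancing the probe buffers items, so the other copy still
    # starts from the first item.
    gen, probe = itertools.tee(gen)
    first_item = next(probe)
    # Re-wrap in a generator expression so the result is still a
    # types.GeneratorType (tee objects are not), matching the type
    # check used in the pipeline code above.
    return (x for x in gen), first_item

items = (n * n for n in range(5))
items, first = peek(items)
assert first == 0
assert isinstance(items, types.GeneratorType)
assert list(items) == [0, 1, 4, 9, 16]  # the peeked item is not lost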
@@ -292,6 +292,50 @@ class TextGenerationPipelineTests(unittest.TestCase):
             ],
         )
 
+    @require_torch
+    def test_small_chat_model_with_iterator_pt(self):
+        from transformers.pipelines.pt_utils import PipelineIterator
+
+        text_generator = pipeline(
+            task="text-generation", model="hf-internal-testing/tiny-gpt2-with-chatml-template", framework="pt"
+        )
+
+        # Using `do_sample=False` to force deterministic output
+        chat1 = [
+            {"role": "system", "content": "This is a system message."},
+            {"role": "user", "content": "This is a test"},
+        ]
+        chat2 = [
+            {"role": "system", "content": "This is a system message."},
+            {"role": "user", "content": "This is a second test"},
+        ]
+        expected_chat1 = chat1 + [
+            {
+                "role": "assistant",
+                "content": " factors factors factors factors factors factors factors factors factors factors",
+            }
+        ]
+        expected_chat2 = chat2 + [
+            {
+                "role": "assistant",
+                "content": " stairs stairs stairs stairs stairs stairs stairs stairs stairs stairs",
+            }
+        ]
+
+        def data():
+            yield from [chat1, chat2]
+
+        outputs = text_generator(data(), do_sample=False, max_new_tokens=10)
+        assert isinstance(outputs, PipelineIterator)
+        outputs = list(outputs)
+        self.assertEqual(
+            outputs,
+            [
+                [{"generated_text": expected_chat1}],
+                [{"generated_text": expected_chat2}],
+            ],
+        )
+
     @require_tf
     def test_small_model_tf(self):
         text_generator = pipeline(task="text-generation", model="sshleifer/tiny-ctrl", framework="tf")