mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
Adding num_return_sequences
support for text2text generation. (#14988)
* Adding `num_return_sequences` support for text2text generation.
Co-Authored-By: Enze <pu.miao@foxmail.com>

* Update tests/test_pipelines_text2text_generation.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update tests/test_pipelines_text2text_generation.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Co-authored-by: Enze <pu.miao@foxmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent
c043ce6cfd
commit
f8a989cfb2
@ -157,18 +157,20 @@ class Text2TextGenerationPipeline(Pipeline):
|
||||
return {"output_ids": output_ids}
|
||||
|
||||
def postprocess(self, model_outputs, return_type=ReturnType.TEXT, clean_up_tokenization_spaces=False):
    """Convert the raw generated token ids into the records returned to the user.

    Args:
        model_outputs: dict with an ``"output_ids"`` key holding the generated
            sequences — one entry per returned sequence, so iterating it yields
            each of the ``num_return_sequences`` generations.
            (assumes entries are token-id sequences accepted by
            ``tokenizer.decode`` — TODO confirm against ``_forward``)
        return_type: ``ReturnType.TENSORS`` to return raw token ids,
            ``ReturnType.TEXT`` to return decoded strings.
        clean_up_tokenization_spaces: forwarded verbatim to ``tokenizer.decode``.

    Returns:
        list[dict]: one record per generated sequence, keyed
        ``f"{self.return_name}_token_ids"`` or ``f"{self.return_name}_text"``.
    """
    records = []
    for output_ids in model_outputs["output_ids"]:
        if return_type == ReturnType.TENSORS:
            # Fix: store *this* sequence's token ids, not the whole
            # model_outputs dict — the previous code appended N identical
            # records containing the full dict when num_return_sequences > 1.
            record = {f"{self.return_name}_token_ids": output_ids}
        elif return_type == ReturnType.TEXT:
            record = {
                f"{self.return_name}_text": self.tokenizer.decode(
                    output_ids,
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                )
            }
        records.append(record)
    return records
|
||||
|
||||
|
||||
@add_end_docstrings(PIPELINE_INIT_ARGS)
|
||||
|
@ -50,6 +50,19 @@ class Text2TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTest
|
||||
outputs = generator("Something there", do_sample=False)
|
||||
self.assertEqual(outputs, [{"generated_text": ""}])
|
||||
|
||||
num_return_sequences = 3
|
||||
outputs = generator(
|
||||
"Something there",
|
||||
num_return_sequences=num_return_sequences,
|
||||
num_beams=num_return_sequences,
|
||||
)
|
||||
target_outputs = [
|
||||
{"generated_text": "Beide Beide Beide Beide Beide Beide Beide Beide Beide"},
|
||||
{"generated_text": "Beide Beide Beide Beide Beide Beide Beide Beide"},
|
||||
{"generated_text": ""},
|
||||
]
|
||||
self.assertEqual(outputs, target_outputs)
|
||||
|
||||
@require_tf
|
||||
def test_small_model_tf(self):
|
||||
generator = pipeline("text2text-generation", model="patrickvonplaten/t5-tiny-random", framework="tf")
|
||||
|
Loading…
Reference in New Issue
Block a user