Adding num_return_sequences support for text2text generation. (#14988)

* Adding `num_return_sequences` support for text2text generation.

Co-Authored-By: Enze <pu.miao@foxmail.com>

* Update tests/test_pipelines_text2text_generation.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update tests/test_pipelines_text2text_generation.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Co-authored-by: Enze <pu.miao@foxmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
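For reference, a minimal usage sketch of what this change enables (the checkpoint and prompt below are illustrative, not taken from this PR):

from transformers import pipeline

# Any seq2seq checkpoint works here; "t5-small" is only an example.
generator = pipeline("text2text-generation", model="t5-small")

# With this change, requesting several sequences yields one dict per returned sequence.
outputs = generator(
    "translate English to German: The house is wonderful.",
    num_return_sequences=3,
    num_beams=3,
)
# outputs is now a list like:
# [{"generated_text": "..."}, {"generated_text": "..."}, {"generated_text": "..."}]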
Author: Nicolas Patry
Committed: 2021-12-30 16:17:15 +01:00 (via GitHub)
Parent: c043ce6cfd
Commit: f8a989cfb2
2 changed files with 27 additions and 12 deletions

src/transformers/pipelines/text2text_generation.py

@@ -157,18 +157,20 @@ class Text2TextGenerationPipeline(Pipeline):
         return {"output_ids": output_ids}
 
     def postprocess(self, model_outputs, return_type=ReturnType.TEXT, clean_up_tokenization_spaces=False):
-        record = {}
-        if return_type == ReturnType.TENSORS:
-            record = {f"{self.return_name}_token_ids": model_outputs}
-        elif return_type == ReturnType.TEXT:
-            record = {
-                f"{self.return_name}_text": self.tokenizer.decode(
-                    model_outputs["output_ids"][0],
-                    skip_special_tokens=True,
-                    clean_up_tokenization_spaces=clean_up_tokenization_spaces,
-                )
-            }
-        return record
+        records = []
+        for output_ids in model_outputs["output_ids"]:
+            if return_type == ReturnType.TENSORS:
+                record = {f"{self.return_name}_token_ids": model_outputs}
+            elif return_type == ReturnType.TEXT:
+                record = {
+                    f"{self.return_name}_text": self.tokenizer.decode(
+                        output_ids,
+                        skip_special_tokens=True,
+                        clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+                    )
+                }
+            records.append(record)
+        return records
 
 
 @add_end_docstrings(PIPELINE_INIT_ARGS)
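The loop introduced above is needed because generate() returns num_return_sequences output rows per input when several sequences are requested; a rough standalone sketch of that behaviour (checkpoint and prompt are illustrative):

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

inputs = tokenizer("translate English to German: Hello", return_tensors="pt")
# output_ids has shape (num_return_sequences, sequence_length) for a single input,
# which is why postprocess now decodes each row into its own record.
output_ids = model.generate(**inputs, num_beams=3, num_return_sequences=3)
print([tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids])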

tests/test_pipelines_text2text_generation.py

@@ -50,6 +50,19 @@ class Text2TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTest
         outputs = generator("Something there", do_sample=False)
         self.assertEqual(outputs, [{"generated_text": ""}])
 
+        num_return_sequences = 3
+        outputs = generator(
+            "Something there",
+            num_return_sequences=num_return_sequences,
+            num_beams=num_return_sequences,
+        )
+        target_outputs = [
+            {"generated_text": "Beide Beide Beide Beide Beide Beide Beide Beide Beide"},
+            {"generated_text": "Beide Beide Beide Beide Beide Beide Beide Beide"},
+            {"generated_text": ""},
+        ]
+        self.assertEqual(outputs, target_outputs)
+
     @require_tf
     def test_small_model_tf(self):
         generator = pipeline("text2text-generation", model="patrickvonplaten/t5-tiny-random", framework="tf")