Fixing return type tensor with num_return_sequences>1. (#16828)

* Fixing return type tensor with `num_return_sequences>1`.

* Nit.
This commit is contained in:
Nicolas Patry 2022-04-20 16:11:51 +02:00 committed by GitHub
parent ff06b17791
commit e13a91fe60
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 69 additions and 2 deletions

View File

@ -168,7 +168,7 @@ class Text2TextGenerationPipeline(Pipeline):
records = []
for output_ids in model_outputs["output_ids"][0]:
if return_type == ReturnType.TENSORS:
record = {f"{self.return_name}_token_ids": model_outputs}
record = {f"{self.return_name}_token_ids": output_ids}
elif return_type == ReturnType.TEXT:
record = {
f"{self.return_name}_text": self.tokenizer.decode(

View File

@ -226,7 +226,7 @@ class TextGenerationPipeline(Pipeline):
records = []
for sequence in generated_sequence:
if return_type == ReturnType.TENSORS:
record = {"generated_token_ids": generated_sequence}
record = {"generated_token_ids": sequence}
elif return_type in {ReturnType.NEW_TEXT, ReturnType.FULL_TEXT}:
# Decode text
text = self.tokenizer.decode(

View File

@ -21,10 +21,15 @@ from transformers import (
pipeline,
)
from transformers.testing_utils import is_pipeline_test, require_tf, require_torch
from transformers.utils import is_torch_available
from .test_pipelines_common import ANY, PipelineTestCaseMeta
if is_torch_available():
import torch
@is_pipeline_test
class Text2TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
@ -83,6 +88,37 @@ class Text2TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTest
]
self.assertEqual(outputs, target_outputs)
outputs = generator("This is a test", do_sample=True, num_return_sequences=2, return_tensors=True)
self.assertEqual(
outputs,
[
{"generated_token_ids": ANY(torch.Tensor)},
{"generated_token_ids": ANY(torch.Tensor)},
],
)
generator.tokenizer.pad_token_id = generator.model.config.eos_token_id
generator.tokenizer.pad_token = "<pad>"
outputs = generator(
["This is a test", "This is a second test"],
do_sample=True,
num_return_sequences=2,
batch_size=2,
return_tensors=True,
)
self.assertEqual(
outputs,
[
[
{"generated_token_ids": ANY(torch.Tensor)},
{"generated_token_ids": ANY(torch.Tensor)},
],
[
{"generated_token_ids": ANY(torch.Tensor)},
{"generated_token_ids": ANY(torch.Tensor)},
],
],
)
@require_tf
def test_small_model_tf(self):
generator = pipeline("text2text-generation", model="patrickvonplaten/t5-tiny-random", framework="tf")

View File

@ -56,6 +56,37 @@ class TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseM
],
)
outputs = text_generator("This is a test", do_sample=True, num_return_sequences=2, return_tensors=True)
self.assertEqual(
outputs,
[
{"generated_token_ids": ANY(list)},
{"generated_token_ids": ANY(list)},
],
)
text_generator.tokenizer.pad_token_id = text_generator.model.config.eos_token_id
text_generator.tokenizer.pad_token = "<pad>"
outputs = text_generator(
["This is a test", "This is a second test"],
do_sample=True,
num_return_sequences=2,
batch_size=2,
return_tensors=True,
)
self.assertEqual(
outputs,
[
[
{"generated_token_ids": ANY(list)},
{"generated_token_ids": ANY(list)},
],
[
{"generated_token_ids": ANY(list)},
{"generated_token_ids": ANY(list)},
],
],
)
@require_tf
def test_small_model_tf(self):
text_generator = pipeline(task="text-generation", model="sshleifer/tiny-ctrl", framework="tf")