Fixing support batch_size and num_return_Sequences in text-generation pipeline (#15318)

* Fixing support `batch_size` and `num_return_Sequences` in
`text-generation` pipeline

And `text2text-generation` too.

The bug was caused by the batch_size containing both the incoming batch
**and** the generated `num_sequences`.

The fix simply consists into splitting both of these again into
different dimensions.

* TF support.

* Odd backward compatibility script in the way.
This commit is contained in:
Nicolas Patry 2022-01-28 12:15:30 +01:00 committed by GitHub
parent c4d1fd77fa
commit 06107541d3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 77 additions and 17 deletions

View File

@ -136,7 +136,11 @@ class Text2TextGenerationPipeline(Pipeline):
"""
result = super().__call__(*args, **kwargs)
if isinstance(args[0], list) and all(isinstance(el, str) for el in args[0]):
if (
isinstance(args[0], list)
and all(isinstance(el, str) for el in args[0])
and all(len(res) == 1 for res in result)
):
return [res[0] for res in result]
return result
@ -146,19 +150,24 @@ class Text2TextGenerationPipeline(Pipeline):
def _forward(self, model_inputs, **generate_kwargs):
if self.framework == "pt":
input_length = model_inputs["input_ids"].shape[-1]
in_b, input_length = model_inputs["input_ids"].shape
elif self.framework == "tf":
input_length = tf.shape(model_inputs["input_ids"])[-1].numpy()
in_b, input_length = tf.shape(model_inputs["input_ids"]).numpy()
generate_kwargs["min_length"] = generate_kwargs.get("min_length", self.model.config.min_length)
generate_kwargs["max_length"] = generate_kwargs.get("max_length", self.model.config.max_length)
self.check_inputs(input_length, generate_kwargs["min_length"], generate_kwargs["max_length"])
output_ids = self.model.generate(**model_inputs, **generate_kwargs)
out_b = output_ids.shape[0]
if self.framework == "pt":
output_ids = output_ids.reshape(in_b, out_b // in_b, *output_ids.shape[1:])
elif self.framework == "tf":
output_ids = tf.reshape(output_ids, (in_b, out_b // in_b, *output_ids.shape[1:]))
return {"output_ids": output_ids}
def postprocess(self, model_outputs, return_type=ReturnType.TEXT, clean_up_tokenization_spaces=False):
records = []
for output_ids in model_outputs["output_ids"]:
for output_ids in model_outputs["output_ids"][0]:
if return_type == ReturnType.TENSORS:
record = {f"{self.return_name}_token_ids": model_outputs}
elif return_type == ReturnType.TEXT:

View File

@ -2,10 +2,14 @@ import enum
from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, TF_MODEL_FOR_CAUSAL_LM_MAPPING
from ..file_utils import add_end_docstrings
from ..file_utils import add_end_docstrings, is_tf_available
from .base import PIPELINE_INIT_ARGS, Pipeline
if is_tf_available():
import tensorflow as tf
class ReturnType(enum.Enum):
TENSORS = 0
NEW_TEXT = 1
@ -202,23 +206,29 @@ class TextGenerationPipeline(Pipeline):
# Allow empty prompts
if input_ids.shape[1] == 0:
input_ids = None
in_b = 1
else:
in_b = input_ids.shape[0]
prompt_text = model_inputs.pop("prompt_text")
generated_sequence = self.model.generate(input_ids=input_ids, **generate_kwargs) # BS x SL
out_b = generated_sequence.shape[0]
if self.framework == "pt":
generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])
elif self.framework == "tf":
generated_sequence = tf.reshape(generated_sequence, (in_b, out_b // in_b, *generated_sequence.shape[1:]))
return {"generated_sequence": generated_sequence, "input_ids": input_ids, "prompt_text": prompt_text}
def postprocess(self, model_outputs, return_type=ReturnType.FULL_TEXT, clean_up_tokenization_spaces=True):
generated_sequence = model_outputs["generated_sequence"]
generated_sequence = model_outputs["generated_sequence"][0]
input_ids = model_outputs["input_ids"]
prompt_text = model_outputs["prompt_text"]
if self.framework == "pt" and generated_sequence is not None:
generated_sequence = generated_sequence.cpu()
generated_sequence = generated_sequence.numpy().tolist()
if return_type == ReturnType.TENSORS:
record = {"generated_token_ids": generated_sequence}
elif return_type in {ReturnType.NEW_TEXT, ReturnType.FULL_TEXT}:
# Decode text
record = []
for sequence in generated_sequence:
records = []
for sequence in generated_sequence:
if return_type == ReturnType.TENSORS:
record = {"generated_token_ids": generated_sequence}
elif return_type in {ReturnType.NEW_TEXT, ReturnType.FULL_TEXT}:
# Decode text
text = self.tokenizer.decode(
sequence,
skip_special_tokens=True,
@ -242,7 +252,7 @@ class TextGenerationPipeline(Pipeline):
else:
all_text = text[prompt_length:]
item = {"generated_text": all_text}
record.append(item)
record = {"generated_text": all_text}
records.append(record)
return record
return records

View File

@ -40,6 +40,26 @@ class Text2TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTest
# These are encoder decoder, they don't just append to incoming string
self.assertFalse(outputs[0]["generated_text"].startswith("Something there"))
outputs = generator(["This is great !", "Something else"], num_return_sequences=2, do_sample=True)
self.assertEqual(
outputs,
[
[{"generated_text": ANY(str)}, {"generated_text": ANY(str)}],
[{"generated_text": ANY(str)}, {"generated_text": ANY(str)}],
],
)
outputs = generator(
["This is great !", "Something else"], num_return_sequences=2, batch_size=2, do_sample=True
)
self.assertEqual(
outputs,
[
[{"generated_text": ANY(str)}, {"generated_text": ANY(str)}],
[{"generated_text": ANY(str)}, {"generated_text": ANY(str)}],
],
)
with self.assertRaises(ValueError):
generator(4)

View File

@ -113,6 +113,27 @@ class TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseM
self.assertEqual(outputs, [{"generated_text": ANY(str)}])
self.assertTrue(outputs[0]["generated_text"].startswith("This is a test"))
outputs = text_generator(["This is great !", "Something else"], num_return_sequences=2, do_sample=True)
self.assertEqual(
outputs,
[
[{"generated_text": ANY(str)}, {"generated_text": ANY(str)}],
[{"generated_text": ANY(str)}, {"generated_text": ANY(str)}],
],
)
if text_generator.tokenizer.pad_token is not None:
outputs = text_generator(
["This is great !", "Something else"], num_return_sequences=2, batch_size=2, do_sample=True
)
self.assertEqual(
outputs,
[
[{"generated_text": ANY(str)}, {"generated_text": ANY(str)}],
[{"generated_text": ANY(str)}, {"generated_text": ANY(str)}],
],
)
# Empty prompt is slighly special
# it requires BOS token to exist.
# Special case for Pegasus which will always append EOS so will