Fixing support `batch_size` and `num_return_sequences` in `text-generation` pipeline (#15318)
* Fixing support for `batch_size` and `num_return_sequences` in the `text-generation` pipeline, and in `text2text-generation` too. The bug was caused by the batch dimension containing both the incoming batch **and** the generated `num_return_sequences`. The fix consists in splitting these back into two separate dimensions.

* TF support.

* Odd backward-compatibility script in the way.
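The core of the fix is a single reshape: `generate` returns a flat batch of `in_b * num_return_sequences` sequences, and the pipeline now splits that back into separate batch and per-prompt dimensions before postprocessing. A minimal sketch, using NumPy as a stand-in for the PyTorch/TF tensors (`in_b`/`out_b` follow the naming used in the diff below):

    import numpy as np

    in_b = 2                  # incoming batch: 2 prompts
    num_return_sequences = 3  # generations requested per prompt
    seq_len = 5

    # generate() hands back a flat batch of in_b * num_return_sequences rows.
    flat = np.zeros((in_b * num_return_sequences, seq_len))
    out_b = flat.shape[0]

    # Split the two dimensions apart so generations stay grouped by prompt.
    grouped = flat.reshape(in_b, out_b // in_b, *flat.shape[1:])
    assert grouped.shape == (2, 3, 5)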
parent c4d1fd77fa
commit 06107541d3
@@ -136,7 +136,11 @@ class Text2TextGenerationPipeline(Pipeline):
         """
         result = super().__call__(*args, **kwargs)
-        if isinstance(args[0], list) and all(isinstance(el, str) for el in args[0]):
+        if (
+            isinstance(args[0], list)
+            and all(isinstance(el, str) for el in args[0])
+            and all(len(res) == 1 for res in result)
+        ):
             return [res[0] for res in result]
         return result
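The new `all(len(res) == 1 for res in result)` guard preserves backward compatibility: when every prompt produced exactly one sequence, the nesting is flattened away as before, while `num_return_sequences > 1` keeps one inner list per prompt. A self-contained illustration of that contract (placeholder strings, not real model output):

    # Each prompt yielded a single sequence, so the wrapper lists collapse.
    result = [[{"generated_text": "a"}], [{"generated_text": "b"}]]
    if all(len(res) == 1 for res in result):
        result = [res[0] for res in result]
    assert result == [{"generated_text": "a"}, {"generated_text": "b"}]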
@@ -146,19 +150,24 @@ class Text2TextGenerationPipeline(Pipeline):

     def _forward(self, model_inputs, **generate_kwargs):
         if self.framework == "pt":
-            input_length = model_inputs["input_ids"].shape[-1]
+            in_b, input_length = model_inputs["input_ids"].shape
         elif self.framework == "tf":
-            input_length = tf.shape(model_inputs["input_ids"])[-1].numpy()
+            in_b, input_length = tf.shape(model_inputs["input_ids"]).numpy()

         generate_kwargs["min_length"] = generate_kwargs.get("min_length", self.model.config.min_length)
         generate_kwargs["max_length"] = generate_kwargs.get("max_length", self.model.config.max_length)
         self.check_inputs(input_length, generate_kwargs["min_length"], generate_kwargs["max_length"])
         output_ids = self.model.generate(**model_inputs, **generate_kwargs)
+        out_b = output_ids.shape[0]
+        if self.framework == "pt":
+            output_ids = output_ids.reshape(in_b, out_b // in_b, *output_ids.shape[1:])
+        elif self.framework == "tf":
+            output_ids = tf.reshape(output_ids, (in_b, out_b // in_b, *output_ids.shape[1:]))
         return {"output_ids": output_ids}

     def postprocess(self, model_outputs, return_type=ReturnType.TEXT, clean_up_tokenization_spaces=False):
         records = []
-        for output_ids in model_outputs["output_ids"]:
+        for output_ids in model_outputs["output_ids"][0]:
             if return_type == ReturnType.TENSORS:
                 record = {f"{self.return_name}_token_ids": model_outputs}
             elif return_type == ReturnType.TEXT:
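The `in_b, input_length = ...` unpacking above reads both dimensions of the padded input batch in one go. A minimal sketch of the PyTorch branch (the TF branch does the same via `tf.shape(...).numpy()`, which yields a 2-element array):

    import torch

    input_ids = torch.ones(4, 7, dtype=torch.long)  # (batch, sequence length)
    in_b, input_length = input_ids.shape
    assert (in_b, input_length) == (4, 7)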
@@ -2,10 +2,14 @@ import enum

 from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, TF_MODEL_FOR_CAUSAL_LM_MAPPING

-from ..file_utils import add_end_docstrings
+from ..file_utils import add_end_docstrings, is_tf_available
 from .base import PIPELINE_INIT_ARGS, Pipeline


+if is_tf_available():
+    import tensorflow as tf
+
+
 class ReturnType(enum.Enum):
     TENSORS = 0
     NEW_TEXT = 1
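With `tf.reshape` now used in `_forward`, `tensorflow` has to be importable here, but only when it is actually installed; the guard keeps the module usable in torch-only environments. The same pattern standalone (as of this commit the helper lives in `transformers.file_utils`):

    from transformers.file_utils import is_tf_available

    # Import TensorFlow lazily; `tf` is only touched when framework == "tf".
    if is_tf_available():
        import tensorflow as tf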
@@ -202,23 +206,29 @@ class TextGenerationPipeline(Pipeline):
         # Allow empty prompts
         if input_ids.shape[1] == 0:
             input_ids = None
+            in_b = 1
+        else:
+            in_b = input_ids.shape[0]
         prompt_text = model_inputs.pop("prompt_text")
         generated_sequence = self.model.generate(input_ids=input_ids, **generate_kwargs)  # BS x SL
+        out_b = generated_sequence.shape[0]
+        if self.framework == "pt":
+            generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])
+        elif self.framework == "tf":
+            generated_sequence = tf.reshape(generated_sequence, (in_b, out_b // in_b, *generated_sequence.shape[1:]))
         return {"generated_sequence": generated_sequence, "input_ids": input_ids, "prompt_text": prompt_text}

     def postprocess(self, model_outputs, return_type=ReturnType.FULL_TEXT, clean_up_tokenization_spaces=True):
-        generated_sequence = model_outputs["generated_sequence"]
+        generated_sequence = model_outputs["generated_sequence"][0]
         input_ids = model_outputs["input_ids"]
         prompt_text = model_outputs["prompt_text"]
         if self.framework == "pt" and generated_sequence is not None:
             generated_sequence = generated_sequence.cpu()
         generated_sequence = generated_sequence.numpy().tolist()
-        if return_type == ReturnType.TENSORS:
-            record = {"generated_token_ids": generated_sequence}
-        elif return_type in {ReturnType.NEW_TEXT, ReturnType.FULL_TEXT}:
-            # Decode text
-            record = []
-            for sequence in generated_sequence:
+        records = []
+        for sequence in generated_sequence:
+            if return_type == ReturnType.TENSORS:
+                record = {"generated_token_ids": generated_sequence}
+            elif return_type in {ReturnType.NEW_TEXT, ReturnType.FULL_TEXT}:
+                # Decode text
                 text = self.tokenizer.decode(
                     sequence,
                     skip_special_tokens=True,
@@ -242,7 +252,7 @@ class TextGenerationPipeline(Pipeline):
                 else:
                     all_text = text[prompt_length:]

-                item = {"generated_text": all_text}
-                record.append(item)
+                record = {"generated_text": all_text}
+            records.append(record)

-        return record
+        return records
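`postprocess` handles one pipeline item at a time, so `model_outputs["generated_sequence"][0]` selects that item's group of `num_return_sequences` generations, and the rewritten loop emits one record per sequence. A simplified sketch of the resulting structure, one dict per generated sequence (made-up token ids):

    generated_sequence = [[101, 7, 8], [101, 9, 4]]  # 2 sequences for one prompt

    records = []
    for sequence in generated_sequence:
        record = {"generated_token_ids": sequence}
        records.append(record)

    assert len(records) == 2  # one record per returned sequence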
@@ -40,6 +40,26 @@ class Text2TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
         # These are encoder decoder, they don't just append to incoming string
         self.assertFalse(outputs[0]["generated_text"].startswith("Something there"))

+        outputs = generator(["This is great !", "Something else"], num_return_sequences=2, do_sample=True)
+        self.assertEqual(
+            outputs,
+            [
+                [{"generated_text": ANY(str)}, {"generated_text": ANY(str)}],
+                [{"generated_text": ANY(str)}, {"generated_text": ANY(str)}],
+            ],
+        )
+
+        outputs = generator(
+            ["This is great !", "Something else"], num_return_sequences=2, batch_size=2, do_sample=True
+        )
+        self.assertEqual(
+            outputs,
+            [
+                [{"generated_text": ANY(str)}, {"generated_text": ANY(str)}],
+                [{"generated_text": ANY(str)}, {"generated_text": ANY(str)}],
+            ],
+        )
+
         with self.assertRaises(ValueError):
             generator(4)
@@ -113,6 +113,27 @@ class TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
         self.assertEqual(outputs, [{"generated_text": ANY(str)}])
         self.assertTrue(outputs[0]["generated_text"].startswith("This is a test"))

+        outputs = text_generator(["This is great !", "Something else"], num_return_sequences=2, do_sample=True)
+        self.assertEqual(
+            outputs,
+            [
+                [{"generated_text": ANY(str)}, {"generated_text": ANY(str)}],
+                [{"generated_text": ANY(str)}, {"generated_text": ANY(str)}],
+            ],
+        )
+
+        if text_generator.tokenizer.pad_token is not None:
+            outputs = text_generator(
+                ["This is great !", "Something else"], num_return_sequences=2, batch_size=2, do_sample=True
+            )
+            self.assertEqual(
+                outputs,
+                [
+                    [{"generated_text": ANY(str)}, {"generated_text": ANY(str)}],
+                    [{"generated_text": ANY(str)}, {"generated_text": ANY(str)}],
+                ],
+            )
+
         # Empty prompt is slighly special
         # it requires BOS token to exist.
         # Special case for Pegasus which will always append EOS so will
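Taken together, the user-visible contract the new tests pin down: a list of prompts with `num_return_sequences=2` returns one inner list per prompt, with one dict per sequence. A usage sketch (`gpt2` is just an example checkpoint; sampled text varies):

    from transformers import pipeline

    generator = pipeline("text-generation", model="gpt2")
    out = generator(["Hello there", "Once upon a time"], num_return_sequences=2, do_sample=True)

    assert len(out) == 2 and len(out[0]) == 2  # 2 prompts x 2 sequences each
    print(out[0][0]["generated_text"])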