Fixing a hard-to-trigger bug in the text-generation pipeline. (#18131)

* Fixing a bug where the attention mask was not passed to generate() (see the sketch below).

* Fixing zero-size prompts.

* Moved the "# BS x SL" comment onto its own line above the generate() call.
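As an illustration of the first fix, here is a standalone sketch (not part of the PR; the model name and prompts are arbitrary assumptions) of why the attention mask must reach generate(): once the tokenizer pads a batch, dropping the mask lets the pad tokens be treated as real context.

# Sketch only -- gpt2 and the prompts below are illustrative assumptions.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default
model = AutoModelForCausalLM.from_pretrained("gpt2")

enc = tokenizer(["Hi", "A much longer prompt"], return_tensors="pt", padding=True)

# Mask passed: padding positions are ignored during generation.
with_mask = model.generate(
    input_ids=enc["input_ids"],
    attention_mask=enc["attention_mask"],
    max_new_tokens=5,
)

# Mask dropped (the pre-fix pipeline behaviour): the padded row can decode
# differently because the pad tokens are attended to.
without_mask = model.generate(input_ids=enc["input_ids"], max_new_tokens=5)

print(tokenizer.batch_decode(with_mask, skip_special_tokens=True))
print(tokenizer.batch_decode(without_mask, skip_special_tokens=True))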
Nicolas Patry 2022-07-15 15:54:07 +02:00 committed by GitHub
parent 8581a798c0
commit fca66ec4ef

@@ -205,14 +205,17 @@ class TextGenerationPipeline(Pipeline):
     def _forward(self, model_inputs, **generate_kwargs):
         input_ids = model_inputs["input_ids"]
+        attention_mask = model_inputs.get("attention_mask", None)
         # Allow empty prompts
         if input_ids.shape[1] == 0:
             input_ids = None
+            attention_mask = None
             in_b = 1
         else:
             in_b = input_ids.shape[0]
         prompt_text = model_inputs.pop("prompt_text")
-        generated_sequence = self.model.generate(input_ids=input_ids, **generate_kwargs)  # BS x SL
+        # BS x SL
+        generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
         out_b = generated_sequence.shape[0]
         if self.framework == "pt":
             generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])
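For the second fix, a hedged usage sketch (gpt2 and the prompt strings are assumptions, not part of the commit): with the patch, an empty prompt sets both input_ids and attention_mask to None before generate(), so zero-size prompts should pass through the pipeline rather than crash.

# Usage sketch only; model name and outputs are illustrative.
from transformers import pipeline

generator = pipeline("text-generation", model="gpt2")

print(generator("", max_new_tokens=5))       # zero-size prompt: input_ids/attention_mask become None
print(generator("Hello", max_new_tokens=5))  # regular prompt keeps its attention mask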