mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fixing a hard-to-trigger bug in the text-generation
pipeline. (#18131)
* Fixing a bug where the attention mask was not passed to generate.
* Fixing zero-size prompts.
* Comment on top.
This commit is contained in:
parent
8581a798c0
commit
fca66ec4ef
@ -205,14 +205,17 @@ class TextGenerationPipeline(Pipeline):
|
||||
|
||||
def _forward(self, model_inputs, **generate_kwargs):
|
||||
input_ids = model_inputs["input_ids"]
|
||||
attention_mask = model_inputs.get("attention_mask", None)
|
||||
# Allow empty prompts
|
||||
if input_ids.shape[1] == 0:
|
||||
input_ids = None
|
||||
attention_mask = None
|
||||
in_b = 1
|
||||
else:
|
||||
in_b = input_ids.shape[0]
|
||||
prompt_text = model_inputs.pop("prompt_text")
|
||||
generated_sequence = self.model.generate(input_ids=input_ids, **generate_kwargs) # BS x SL
|
||||
# BS x SL
|
||||
generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
|
||||
out_b = generated_sequence.shape[0]
|
||||
if self.framework == "pt":
|
||||
generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])
|
||||
|
Loading…
Reference in New Issue
Block a user