Add inputs_embeds functionality when generating with BioGPT (#21889)

* initial commit to add inputs_embeds to generation

* formatting
Sid Kiblawi 2023-03-02 04:43:19 -08:00 committed by GitHub
parent 3412f5979d
commit edbb37f736


@@ -703,17 +703,27 @@ class BioGptForCausalLM(BioGptPreTrainedModel):
             cross_attentions=outputs.cross_attentions,
         )
 
-    def prepare_inputs_for_generation(self, input_ids, attention_mask, past_key_values=None, **kwargs):
+    def prepare_inputs_for_generation(
+        self, input_ids, attention_mask, inputs_embeds=None, past_key_values=None, **kwargs
+    ):
         # only last token for input_ids if past is defined in kwargs
         if past_key_values:
             input_ids = input_ids[:, -1].unsqueeze(-1)
 
-        return {
-            "input_ids": input_ids,
-            "attention_mask": attention_mask,
-            "past_key_values": past_key_values,
-            "use_cache": kwargs.get("use_cache"),
-        }
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "attention_mask": attention_mask,
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+            }
+        )
+
+        return model_inputs
 
     @staticmethod
     def _reorder_cache(past_key_values, beam_idx):
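
For readers of this diff, a minimal usage sketch of what the change enables: calling generate() with precomputed prompt embeddings instead of token ids. The microsoft/biogpt checkpoint and the prompt text are illustrative assumptions, not part of this commit.

    # Minimal sketch: generate from inputs_embeds with BioGPT.
    # Assumes the public microsoft/biogpt checkpoint (illustrative only).
    from transformers import BioGptForCausalLM, BioGptTokenizer

    tokenizer = BioGptTokenizer.from_pretrained("microsoft/biogpt")
    model = BioGptForCausalLM.from_pretrained("microsoft/biogpt")

    inputs = tokenizer("COVID-19 is", return_tensors="pt")

    # Embed the prompt ourselves instead of handing generate() the token ids.
    inputs_embeds = model.get_input_embeddings()(inputs.input_ids)

    # First step: prepare_inputs_for_generation forwards inputs_embeds.
    # Later steps: past_key_values exist, so it falls back to the freshly
    # generated input_ids.
    output_ids = model.generate(
        inputs_embeds=inputs_embeds,
        attention_mask=inputs.attention_mask,
        max_new_tokens=20,
    )
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

Note that when only inputs_embeds are supplied, generate() has no prompt token ids to echo, so the returned sequence should contain just the newly generated tokens.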