mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Add inputs_embeds
functionality when generating with BioGPT (#21889)
* initial commit to add inputs_embeds to generation * formatting
This commit is contained in:
parent
3412f5979d
commit
edbb37f736
@ -703,17 +703,27 @@ class BioGptForCausalLM(BioGptPreTrainedModel):
|
||||
cross_attentions=outputs.cross_attentions,
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, attention_mask, past_key_values=None, **kwargs):
|
||||
def prepare_inputs_for_generation(
|
||||
self, input_ids, attention_mask, inputs_embeds=None, past_key_values=None, **kwargs
|
||||
):
|
||||
# only last token for inputs_ids if past is defined in kwargs
|
||||
if past_key_values:
|
||||
input_ids = input_ids[:, -1].unsqueeze(-1)
|
||||
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": kwargs.get("use_cache"),
|
||||
}
|
||||
if inputs_embeds is not None and past_key_values is None:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||
else:
|
||||
model_inputs = {"input_ids": input_ids}
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
"attention_mask": attention_mask,
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": kwargs.get("use_cache"),
|
||||
}
|
||||
)
|
||||
|
||||
return model_inputs
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past_key_values, beam_idx):
|
||||
|
Loading…
Reference in New Issue
Block a user