fix past_key_values in GPTNeoXForCausalLM.prepare_inputs_for_generation (#20621)

* fix past_key_values in GPTNeoXForCausalLM.prepare_inputs_for_generation

* fix formatting
This commit is contained in:
ValeKnappich 2022-12-21 12:46:04 +01:00 committed by GitHub
parent 852e7ebaa2
commit 2da82bb4a7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -697,7 +697,11 @@ class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel):
if past and past[0] is not None:
input_ids = input_ids[:, -1:]
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"past_key_values": past or model_kwargs.get("past_key_values"),
}
def _reorder_cache(self, past, beam_idx):
reordered_past = ()