Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
fix past_key_values in GPTNeoXForCausalLM.prepare_inputs_for_generation (#20621)
* fix past_key_values in GPTNeoXForCausalLM.prepare_inputs_for_generation * fix formatting
This commit is contained in:
parent
852e7ebaa2
commit
2da82bb4a7
@@ -697,7 +697,11 @@ class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel):
         if past and past[0] is not None:
             input_ids = input_ids[:, -1:]

-        return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
+        return {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "past_key_values": past or model_kwargs.get("past_key_values"),
+        }

     def _reorder_cache(self, past, beam_idx):
         reordered_past = ()
Loading…
Reference in New Issue
Block a user