mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
add back the position ids (#32554)
* add back the position ids * fix failing test
This commit is contained in:
parent
f3c8b18053
commit
c215523528
@ -818,6 +818,7 @@ class RecurrentGemmaForCausalLM(RecurrentGemmaPreTrainedModel):
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
@ -858,6 +859,7 @@ class RecurrentGemmaForCausalLM(RecurrentGemmaPreTrainedModel):
|
||||
output_hidden_states = True
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
cache_position=cache_position,
|
||||
attention_mask=attention_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
@ -913,13 +915,17 @@ class RecurrentGemmaForCausalLM(RecurrentGemmaPreTrainedModel):
|
||||
if past_length > 0:
|
||||
position_ids = position_ids[:, past_length:]
|
||||
|
||||
if inputs_embeds is not None:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds[:, past_length:]}
|
||||
else:
|
||||
model_inputs = {"input_ids": input_ids[:, past_length:].contiguous()}
|
||||
if inputs_embeds is not None: # Exception 1
|
||||
input_ids = input_ids[:, -cache_position.shape[0] :]
|
||||
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
|
||||
input_ids = input_ids[:, cache_position]
|
||||
|
||||
if cache_position is not None:
|
||||
cache_position = cache_position[-position_ids.shape[1] :]
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and cache_position[0] == 0:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
|
||||
else:
|
||||
# The clone here is for the same reason as for `position_ids`.
|
||||
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
|
Loading…
Reference in New Issue
Block a user