add back the position ids (#32554)

* add back the position ids

* fix failing test
This commit is contained in:
Authored by Arthur on 2024-08-16 11:00:05 +02:00; committed by GitHub
parent f3c8b18053
commit c215523528
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -818,6 +818,7 @@ class RecurrentGemmaForCausalLM(RecurrentGemmaPreTrainedModel):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
@ -858,6 +859,7 @@ class RecurrentGemmaForCausalLM(RecurrentGemmaPreTrainedModel):
output_hidden_states = True
outputs = self.model(
input_ids=input_ids,
position_ids=position_ids,
cache_position=cache_position,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
@ -913,13 +915,17 @@ class RecurrentGemmaForCausalLM(RecurrentGemmaPreTrainedModel):
if past_length > 0:
position_ids = position_ids[:, past_length:]
if inputs_embeds is not None:
model_inputs = {"inputs_embeds": inputs_embeds[:, past_length:]}
else:
model_inputs = {"input_ids": input_ids[:, past_length:].contiguous()}
if inputs_embeds is not None: # Exception 1
input_ids = input_ids[:, -cache_position.shape[0] :]
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
input_ids = input_ids[:, cache_position]
if cache_position is not None:
cache_position = cache_position[-position_ids.shape[1] :]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
else:
# The clone here is for the same reason as for `position_ids`.
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
model_inputs.update(
{