Fix: Mamba2 generation mismatch between input_ids and inputs_embeds (#32694)

* fix cache when using input embeddings

* simplify the check: we can always add the input_ids sequence length, since it's 0 on the first pass
Author: Anton Vlasjuk (committed by GitHub)
Date: 2024-08-19 16:06:07 +02:00
Commit: 61d89c19d8
Parent: 93e538ae2e


@@ -964,8 +964,8 @@ class Mamba2ForCausalLM(Mamba2PreTrainedModel):
         attention_mask: Optional[torch.Tensor] = None,
         **kwargs,
     ):
-        if input_ids.shape[1] == 0:
-            past_len = inputs_embeds.shape[1]
+        if inputs_embeds is not None:
+            past_len = inputs_embeds.shape[1] + input_ids.shape[1]
         else:
             past_len = input_ids.shape[1]
         if use_cache:
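
For illustration, here is a minimal sketch of the corrected past-length logic. The helper `past_len_for` and the toy tensor shapes are hypothetical, not part of the actual `prepare_inputs_for_generation` method; they only demonstrate why adding `input_ids.shape[1]` is always safe:

```python
from typing import Optional

import torch


def past_len_for(input_ids: torch.Tensor, inputs_embeds: Optional[torch.Tensor]) -> int:
    # Mirrors the fixed branch: when embeddings are passed (first generation
    # pass), input_ids has sequence length 0, so the sum is just the
    # embeddings length; on later decoding passes inputs_embeds is None and
    # input_ids alone carries the new tokens.
    if inputs_embeds is not None:
        return inputs_embeds.shape[1] + input_ids.shape[1]
    return input_ids.shape[1]


# First generation pass: prompt supplied only as embeddings, empty input_ids.
embeds = torch.randn(1, 7, 16)  # (batch, seq, hidden) -- illustrative shape
empty_ids = torch.empty(1, 0, dtype=torch.long)
assert past_len_for(empty_ids, embeds) == 7

# Subsequent decoding pass: one newly generated token id, no embeddings.
next_ids = torch.tensor([[42]])
assert past_len_for(next_ids, None) == 1
```

This is also why the old check `input_ids.shape[1] == 0` was fragile: it indexed `inputs_embeds.shape[1]` without verifying that `inputs_embeds` was actually provided.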