From 61d89c19d853c1e2a597f8dc2bc609988cf2ba4c Mon Sep 17 00:00:00 2001 From: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com> Date: Mon, 19 Aug 2024 16:06:07 +0200 Subject: [PATCH] Fix: Mamba2 generation mismatch between input_ids and inputs_embeds (#32694) * fix cache when using input embeddings * simplify check, we can always add input ids seq len since its 0 in first pass --- src/transformers/models/mamba2/modeling_mamba2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/mamba2/modeling_mamba2.py b/src/transformers/models/mamba2/modeling_mamba2.py index 01d50c01660..bf993ad2f31 100644 --- a/src/transformers/models/mamba2/modeling_mamba2.py +++ b/src/transformers/models/mamba2/modeling_mamba2.py @@ -964,8 +964,8 @@ class Mamba2ForCausalLM(Mamba2PreTrainedModel): attention_mask: Optional[torch.Tensor] = None, **kwargs, ): - if input_ids.shape[1] == 0: - past_len = inputs_embeds.shape[1] + if inputs_embeds is not None: + past_len = inputs_embeds.shape[1] + input_ids.shape[1] else: past_len = input_ids.shape[1] if use_cache: