diff --git a/src/transformers/models/mamba2/modeling_mamba2.py b/src/transformers/models/mamba2/modeling_mamba2.py index 01d50c01660..bf993ad2f31 100644 --- a/src/transformers/models/mamba2/modeling_mamba2.py +++ b/src/transformers/models/mamba2/modeling_mamba2.py @@ -964,8 +964,8 @@ class Mamba2ForCausalLM(Mamba2PreTrainedModel): attention_mask: Optional[torch.Tensor] = None, **kwargs, ): - if input_ids.shape[1] == 0: - past_len = inputs_embeds.shape[1] + if inputs_embeds is not None: + past_len = inputs_embeds.shape[1] + input_ids.shape[1] else: past_len = input_ids.shape[1] if use_cache: