Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-28 16:52:24 +06:00
Fix: Mamba2 generation mismatch between input_ids and inputs_embeds (#32694)
* fix cache when using input embeddings
* simplify check; we can always add the input_ids sequence length, since it is 0 on the first pass
This commit is contained in:
parent 93e538ae2e
commit 61d89c19d8
@@ -964,8 +964,8 @@ class Mamba2ForCausalLM(Mamba2PreTrainedModel):
         attention_mask: Optional[torch.Tensor] = None,
         **kwargs,
     ):
-        if input_ids.shape[1] == 0:
-            past_len = inputs_embeds.shape[1]
+        if inputs_embeds is not None:
+            past_len = inputs_embeds.shape[1] + input_ids.shape[1]
         else:
             past_len = input_ids.shape[1]
         if use_cache:
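For context, here is a minimal, self-contained sketch of why the simplified check is safe. During generation from inputs_embeds, the first forward pass receives the prompt as embeddings together with an empty input_ids tensor, so adding input_ids.shape[1] contributes 0; subsequent decoding steps pass only the newly generated token ids. The helper compute_past_len below is hypothetical and for illustration only; it is not the actual Transformers code.

import torch
from typing import Optional

def compute_past_len(input_ids: torch.Tensor, inputs_embeds: Optional[torch.Tensor]) -> int:
    # Mirrors the logic the diff settles on: when embeddings are supplied,
    # input_ids is empty on the first pass, so the sum is just the prompt length.
    if inputs_embeds is not None:
        return inputs_embeds.shape[1] + input_ids.shape[1]
    return input_ids.shape[1]

# First pass: prompt given as embeddings, input_ids is empty.
embeds = torch.randn(1, 5, 16)                    # (batch, seq_len, hidden_size)
empty_ids = torch.empty(1, 0, dtype=torch.long)   # shape (1, 0)
assert compute_past_len(empty_ids, embeds) == 5

# Later decoding steps: only the new token id is passed, inputs_embeds is None.
step_ids = torch.tensor([[42]])
assert compute_past_len(step_ids, None) == 1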