Fix: Enable prefill phase key value caching of nemotron/minitron models (#34742)

* modeling nemotron kv caching bugfix

Signed-off-by: jeongin601 <0200angela@gmail.com>

* test file deleted

Signed-off-by: jeongin601 <0200angela@gmail.com>

* code refinement

Signed-off-by: jeongin601 <0200angela@gmail.com>

* remove unused variables

Signed-off-by: jeongin601 <0200angela@gmail.com>

* import block sorted

* removed deprecation warning

Signed-off-by: jeongin601 <0200angela@gmail.com>

* removed support for tuple shape past_key_values

Signed-off-by: jeongin601 <0200angela@gmail.com>

* Update conditional statement for cache initialization

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------

Signed-off-by: jeongin601 <0200angela@gmail.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
jeongin601 2024-11-25 17:45:35 +09:00 committed by GitHub
parent 3a8eb74668
commit 318fe25f22
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -24,7 +24,7 @@ import torch.utils.checkpoint
from torch import Size, Tensor, nn
from ...activations import ACT2FN
from ...cache_utils import Cache, StaticCache
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import _flash_attention_forward
@ -783,8 +783,14 @@ class NemotronModel(NemotronPreTrainedModel):
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if use_cache and past_key_values is None:
past_key_values = DynamicCache()
if cache_position is None:
cache_position = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device)
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)