mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix: Enable prefill phase key value caching of nemotron/minitron models (#34742)
* modeling nemotron kv caching bugfix
  Signed-off-by: jeongin601 <0200angela@gmail.com>
* test file deleted
  Signed-off-by: jeongin601 <0200angela@gmail.com>
* code refinement
  Signed-off-by: jeongin601 <0200angela@gmail.com>
* remove unused variables
  Signed-off-by: jeongin601 <0200angela@gmail.com>
* import block sorted
* removed deprecation warning
  Signed-off-by: jeongin601 <0200angela@gmail.com>
* removed support for tuple shape past_key_values
  Signed-off-by: jeongin601 <0200angela@gmail.com>
* Update conditional statement for cache initialization
  Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------

Signed-off-by: jeongin601 <0200angela@gmail.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
parent
3a8eb74668
commit
318fe25f22
@ -24,7 +24,7 @@ import torch.utils.checkpoint
|
||||
from torch import Size, Tensor, nn
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, StaticCache
|
||||
from ...cache_utils import Cache, DynamicCache, StaticCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from ...modeling_flash_attention_utils import _flash_attention_forward
|
||||
@ -783,8 +783,14 @@ class NemotronModel(NemotronPreTrainedModel):
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if use_cache and past_key_values is None:
|
||||
past_key_values = DynamicCache()
|
||||
|
||||
if cache_position is None:
|
||||
cache_position = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device)
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
cache_position = torch.arange(
|
||||
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
||||
)
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user