Avoid incorrect generations for KV caches containing more than `sliding_window` tokens

This commit is contained in:
Tim Beyer 2025-05-15 16:45:47 +02:00 committed by GitHub
parent 27ef46e846
commit 57b7c9ffb4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -593,7 +593,13 @@ class Gemma3TextModel(Gemma2Model):
)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
if past_key_values is not None:
past_seen_tokens = past_key_values.get_seq_length()
if past_seen_tokens == past_key_values.config.sliding_window - 1:
raise ValueError("You must provide cache_position when using KV cache with more than sliding_window tokens.")
else:
past_seen_tokens = 0
cache_position = torch.arange(
past_seen_tokens,
past_seen_tokens + inputs_embeds.shape[1],