[HybridCache] Fix get_seq_length method (#31661)

* fix gemma2

* handle in generate
This commit is contained in:
Sanchit Gandhi 2024-06-27 18:40:40 +01:00 committed by GitHub
parent 464aa74659
commit 1c68f2cafb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 2 additions and 2 deletions

View File

@ -1083,7 +1083,7 @@ class HybridCache(Cache):
# no matter how long the sentence is
return self.max_cache_len
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
def get_seq_length(self, layer_idx: Optional[int] = 0):
return None
def reset(self):

View File

@ -1399,7 +1399,7 @@ class GenerationMixin:
cache = model_kwargs["past_key_values"]
if not isinstance(cache, Cache):
past_length = cache[0][0].shape[2]
elif hasattr(cache, "get_seq_length"):
elif hasattr(cache, "get_seq_length") and cache.get_seq_length() is not None:
past_length = cache.get_seq_length()
if "inputs_embeds" in model_kwargs: