Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
[HybridCache] Fix get_seq_length method (#31661)

* fix gemma2
* handle in generate
This commit is contained in:
parent 464aa74659
commit 1c68f2cafb
src/transformers/cache_utils.py
@@ -1083,7 +1083,7 @@ class HybridCache(Cache):
         # no matter how long the sentence is
         return self.max_cache_len
 
-    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+    def get_seq_length(self, layer_idx: Optional[int] = 0):
         return None
 
     def reset(self):
src/transformers/generation/utils.py
@@ -1399,7 +1399,7 @@ class GenerationMixin:
             cache = model_kwargs["past_key_values"]
             if not isinstance(cache, Cache):
                 past_length = cache[0][0].shape[2]
-            elif hasattr(cache, "get_seq_length"):
+            elif hasattr(cache, "get_seq_length") and cache.get_seq_length() is not None:
                 past_length = cache.get_seq_length()
 
         if "inputs_embeds" in model_kwargs:
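
A rough sketch of why the two changes fit together: HybridCache (used for Gemma 2's mix of sliding-window and global layers) keeps statically pre-allocated buffers, so the number of already-cached tokens cannot be read back from it; get_seq_length therefore reports None, and the guarded branch in generate only trusts that value when it is not None, otherwise keeping the default past length. The toy classes and the initial_past_length helper below are illustrative stand-ins, not transformers API.

from typing import Optional


class ToyHybridCache:
    # Hypothetical stand-in for HybridCache: buffers are pre-allocated,
    # so it cannot report how many tokens are already cached.
    def get_seq_length(self, layer_idx: Optional[int] = 0):
        return None


class ToyDynamicCache:
    # Hypothetical stand-in for a cache that does track its token count.
    def __init__(self, seen_tokens: int):
        self._seen_tokens = seen_tokens

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        return self._seen_tokens


def initial_past_length(cache) -> int:
    # Mirrors the guarded branch added above: only use get_seq_length()
    # when it reports an actual length, otherwise keep the default of 0.
    past_length = 0
    if hasattr(cache, "get_seq_length") and cache.get_seq_length() is not None:
        past_length = cache.get_seq_length()
    return past_length


print(initial_past_length(ToyDynamicCache(seen_tokens=5)))  # 5
print(initial_past_length(ToyHybridCache()))                # 0 (falls back to the default)

Without the extra `is not None` check, the old branch would have assigned None to past_length for HybridCache and broken the cache-position arithmetic that follows in generate.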