diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index c0bd42f2e39..e97cac65f25 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1967,7 +1967,8 @@ class HybridChunkedCache(Cache): else: self.sliding_window = config.sliding_window self.max_cache_len = max_cache_len - self._sliding_window_max_len = min(self.sliding_window, max_cache_len) + # Sliding layers can't be larger than the overall max cache len + self.sliding_window = min(config.sliding_window, self.max_cache_len) self.max_batch_size = max_batch_size self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self._dtype = dtype @@ -1989,7 +1990,7 @@ class HybridChunkedCache(Cache): num_key_value_heads = key_states.shape[1] device = key_states.device global_cache_shape = (self.max_batch_size, num_key_value_heads, self.max_cache_len, self.head_dim) - sliding_cache_shape = (self.max_batch_size, num_key_value_heads, self._sliding_window_max_len, self.head_dim) + sliding_cache_shape = (self.max_batch_size, num_key_value_heads, self.sliding_window, self.head_dim) # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph # breaks when updating the cache. cache_shape = sliding_cache_shape if self.is_sliding[layer_idx] else global_cache_shape @@ -2163,7 +2164,7 @@ class OffloadedHybridCache(HybridChunkedCache): device = key_states.device if self.is_sliding[layer_idx] else self.offload_device pin_memory = not self.is_sliding[layer_idx] global_cache_shape = (self.max_batch_size, num_key_value_heads, self.max_cache_len, self.head_dim) - sliding_cache_shape = (self.max_batch_size, num_key_value_heads, self._sliding_window_max_len, self.head_dim) + sliding_cache_shape = (self.max_batch_size, num_key_value_heads, self.sliding_window, self.head_dim) # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph # breaks when updating the cache. cache_shape = sliding_cache_shape if self.is_sliding[layer_idx] else global_cache_shape