Fix HybridChunkedCache & Llama4 (#38299)

* Update cache_utils.py

* Update cache_utils.py
Cyril Vallez 2025-05-22 17:25:51 +02:00 committed by GitHub
parent d95c864a25
commit 73286d8e29

@@ -1967,7 +1967,8 @@ class HybridChunkedCache(Cache):
         else:
             self.sliding_window = config.sliding_window
         self.max_cache_len = max_cache_len
-        self._sliding_window_max_len = min(self.sliding_window, max_cache_len)
+        # Sliding layers can't be larger than the overall max cache len
+        self.sliding_window = min(config.sliding_window, self.max_cache_len)
         self.max_batch_size = max_batch_size
         self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
         self._dtype = dtype
@@ -1989,7 +1990,7 @@
         num_key_value_heads = key_states.shape[1]
         device = key_states.device
         global_cache_shape = (self.max_batch_size, num_key_value_heads, self.max_cache_len, self.head_dim)
-        sliding_cache_shape = (self.max_batch_size, num_key_value_heads, self._sliding_window_max_len, self.head_dim)
+        sliding_cache_shape = (self.max_batch_size, num_key_value_heads, self.sliding_window, self.head_dim)
         # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
         # breaks when updating the cache.
         cache_shape = sliding_cache_shape if self.is_sliding[layer_idx] else global_cache_shape
@@ -2163,7 +2164,7 @@ class OffloadedHybridCache(HybridChunkedCache):
         device = key_states.device if self.is_sliding[layer_idx] else self.offload_device
         pin_memory = not self.is_sliding[layer_idx]
         global_cache_shape = (self.max_batch_size, num_key_value_heads, self.max_cache_len, self.head_dim)
-        sliding_cache_shape = (self.max_batch_size, num_key_value_heads, self._sliding_window_max_len, self.head_dim)
+        sliding_cache_shape = (self.max_batch_size, num_key_value_heads, self.sliding_window, self.head_dim)
         # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
         # breaks when updating the cache.
         cache_shape = sliding_cache_shape if self.is_sliding[layer_idx] else global_cache_shape
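The net effect of the change: the clamped value now lives directly in self.sliding_window, so the separate _sliding_window_max_len attribute is gone and both cache-shape call sites read self.sliding_window instead. Below is a minimal standalone sketch of the clamping behaviour, not the actual transformers code; the class name _SlidingCacheShapes and the example numbers are illustrative assumptions, and only the min() clamp mirrors the diff.

# Standalone sketch of the clamp introduced in this commit (hypothetical class,
# illustrative sizes). It shows that the sliding cache dimension can no longer
# exceed the overall max cache length.

class _SlidingCacheShapes:
    def __init__(self, sliding_window, max_cache_len, max_batch_size, num_key_value_heads, head_dim):
        self.max_cache_len = max_cache_len
        # Sliding layers can't be larger than the overall max cache len
        self.sliding_window = min(sliding_window, max_cache_len)
        self.max_batch_size = max_batch_size
        self.num_key_value_heads = num_key_value_heads
        self.head_dim = head_dim

    def global_cache_shape(self):
        # Full-attention layers allocate up to max_cache_len positions
        return (self.max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)

    def sliding_cache_shape(self):
        # Sliding/chunked layers allocate at most max_cache_len positions after the clamp
        return (self.max_batch_size, self.num_key_value_heads, self.sliding_window, self.head_dim)


# Example: a Llama4-style chunked window of 8192 combined with a short max_cache_len of 1024
shapes = _SlidingCacheShapes(sliding_window=8192, max_cache_len=1024,
                             max_batch_size=1, num_key_value_heads=8, head_dim=128)
print(shapes.sliding_cache_shape())  # (1, 8, 1024, 128) -- capped, not (1, 8, 8192, 128)

Under this reading, the fix avoids over-allocating the sliding-layer cache (and its OffloadedHybridCache counterpart) whenever the configured sliding window is larger than the requested max cache length.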