Fix HybridChunkedCache & Llama4 (#38299)
* Update cache_utils.py
* Update cache_utils.py
This commit is contained in:
parent d95c864a25
commit 73286d8e29
@@ -1967,7 +1967,8 @@ class HybridChunkedCache(Cache):
         else:
             self.sliding_window = config.sliding_window
         self.max_cache_len = max_cache_len
-        self._sliding_window_max_len = min(self.sliding_window, max_cache_len)
+        # Sliding layers can't be larger than the overall max cache len
+        self.sliding_window = min(config.sliding_window, self.max_cache_len)
         self.max_batch_size = max_batch_size
         self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
         self._dtype = dtype
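In this first hunk the private `_sliding_window_max_len` helper is dropped and the clamp is applied to `self.sliding_window` itself, so any code that reads the public attribute (rather than the removed helper) sees a window that never exceeds the cache capacity. A minimal sketch of the before/after behaviour, using a hypothetical config class and hypothetical sizes rather than values from a real checkpoint:

```python
# Minimal sketch of the clamping change; DummyConfig and the numbers below
# are illustrative assumptions, not values from a real model config.
from dataclasses import dataclass

@dataclass
class DummyConfig:
    sliding_window: int = 8192  # window advertised by the model config

config = DummyConfig()
max_cache_len = 2048            # e.g. a short generation budget

# Before: the clamp lived only in a private helper attribute, while
# `sliding_window` kept the raw config value.
_sliding_window_max_len = min(config.sliding_window, max_cache_len)  # 2048

# After: `sliding_window` itself is clamped, so every consumer of the public
# attribute sees a window that can never exceed the cache capacity.
sliding_window = min(config.sliding_window, max_cache_len)           # 2048
assert sliding_window <= max_cache_len
```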
@@ -1989,7 +1990,7 @@ class HybridChunkedCache(Cache):
         num_key_value_heads = key_states.shape[1]
         device = key_states.device
         global_cache_shape = (self.max_batch_size, num_key_value_heads, self.max_cache_len, self.head_dim)
-        sliding_cache_shape = (self.max_batch_size, num_key_value_heads, self._sliding_window_max_len, self.head_dim)
+        sliding_cache_shape = (self.max_batch_size, num_key_value_heads, self.sliding_window, self.head_dim)
         # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
         # breaks when updating the cache.
         cache_shape = sliding_cache_shape if self.is_sliding[layer_idx] else global_cache_shape
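This hunk sizes the per-layer buffers from the clamped attribute: sliding layers allocate `self.sliding_window` key/value slots while full-attention layers allocate `self.max_cache_len`. A hedged sketch of the shape arithmetic, with made-up batch, head, and length values:

```python
# Rough sketch of how the two buffer shapes are derived; all sizes are made up,
# only the shape arithmetic mirrors the hunk above.
import torch

max_batch_size, num_key_value_heads, head_dim = 1, 8, 128
max_cache_len = 2048
sliding_window = min(8192, max_cache_len)  # clamped as in the first hunk

global_cache_shape = (max_batch_size, num_key_value_heads, max_cache_len, head_dim)
sliding_cache_shape = (max_batch_size, num_key_value_heads, sliding_window, head_dim)

# Full-attention layers get a buffer spanning the whole generation budget;
# sliding-window layers only need `sliding_window` key/value slots.
is_sliding_layer = True
cache_shape = sliding_cache_shape if is_sliding_layer else global_cache_shape
key_cache = torch.zeros(cache_shape, dtype=torch.float16)
```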
@@ -2163,7 +2164,7 @@ class OffloadedHybridCache(HybridChunkedCache):
         device = key_states.device if self.is_sliding[layer_idx] else self.offload_device
         pin_memory = not self.is_sliding[layer_idx]
         global_cache_shape = (self.max_batch_size, num_key_value_heads, self.max_cache_len, self.head_dim)
-        sliding_cache_shape = (self.max_batch_size, num_key_value_heads, self._sliding_window_max_len, self.head_dim)
+        sliding_cache_shape = (self.max_batch_size, num_key_value_heads, self.sliding_window, self.head_dim)
         # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
         # breaks when updating the cache.
         cache_shape = sliding_cache_shape if self.is_sliding[layer_idx] else global_cache_shape
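`OffloadedHybridCache` receives the same one-line shape fix; the surrounding context also shows its placement policy, where sliding layers keep their buffer on the compute device and global layers are allocated on `self.offload_device` with pinned memory. An illustrative sketch of that allocation pattern, in which the devices and sizes are assumptions rather than library defaults:

```python
# Illustrative sketch of the offload pattern seen in the context lines above;
# device choices and sizes are assumptions, not library defaults.
import torch

offload_device = torch.device("cpu")
compute_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
is_sliding_layer = False  # a global-attention layer in this example

# Sliding layers stay on the compute device; global layers live on the offload
# device. Pinning is guarded on CUDA availability so the sketch also runs on
# CPU-only builds.
device = compute_device if is_sliding_layer else offload_device
pin_memory = (not is_sliding_layer) and torch.cuda.is_available()

cache_shape = (1, 8, 2048, 128)  # (batch, kv_heads, cache_len, head_dim)
key_cache = torch.zeros(cache_shape, device=device, pin_memory=pin_memory)
```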