Offloaded hybrid cache for Llama4 (#37401)

* first try (maybe race condition)

* Update cache_utils.py

* cannot avoid the race condition -> use 2 layers

* Update cache_utils.py

* Update cache_utils.py
Cyril Vallez 2025-04-10 11:44:34 +02:00 committed by GitHub
parent 6d8b0b3378
commit fbb2054ed5
3 changed files with 119 additions and 1 deletions
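
For context on how this lands for users: a minimal usage sketch, hedged and not part of the commit. The checkpoint id below is a hypothetical placeholder, and the auto classes are only illustrative; the only new piece is the `cache_implementation` string, which maps to the class added in the diff that follows.

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "path-or-id-of-a-llama4-text-checkpoint"  # hypothetical placeholder, not from this commit
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="bfloat16", device_map="auto")

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
output_ids = model.generate(
    **inputs,
    max_new_tokens=32,
    # "offloaded_hybrid" / "offloaded_hybrid_chunked" both resolve to the new OffloadedHybridCache:
    # full-attention KV layers stay on CPU (pinned) and are prefetched into two alternating GPU buffers.
    cache_implementation="offloaded_hybrid",
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))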

@@ -2011,6 +2011,118 @@ class HybridChunkedCache(Cache):
self.cumulative_length = [0 for _ in range(len(self.cumulative_length))]
class OffloadedHybridCache(HybridChunkedCache):
def __init__(
self,
config: PretrainedConfig,
max_batch_size: int,
max_cache_len: Optional[int] = None,
device: Union[torch.device, str, None] = None,
dtype: torch.dtype = torch.bfloat16,
offload_device: Union[str, torch.device] = torch.device("cpu"),
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
):
super().__init__(config, max_batch_size, max_cache_len, device, dtype, layer_device_map)
self.offload_device = torch.device(offload_device)
# Create new CUDA stream for parallel prefetching.
self._prefetch_stream = torch.cuda.Stream() if torch._C._get_accelerator().type == "cuda" else None
# These on-device layers are created dynamically, like the other layers (to support TP)
self.device_key_cache = None
self.device_value_cache = None
# This gives the index of which on-device full layer to use (we need 2 to avoid race conditions when prefetching)
self.active_device_layer = 0
def initialise_cache_layer(self, layer_idx, key_states):
"""Overriden to use the correct device if offloaded layer (and pin memory)."""
if len(self.key_cache) > layer_idx:
return
num_key_value_heads = key_states.shape[1]
device = key_states.device if self.is_sliding[layer_idx] else self.offload_device
pin_memory = not self.is_sliding[layer_idx]
global_cache_shape = (self.max_batch_size, num_key_value_heads, self.max_cache_len, self.head_dim)
sliding_cache_shape = (
self.max_batch_size,
num_key_value_heads,
self.sliding_window,
self.head_dim,
)
# Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph
# breaks when updating the cache.
cache_shape = sliding_cache_shape if self.is_sliding[layer_idx] else global_cache_shape
new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device, pin_memory=pin_memory)
new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device, pin_memory=pin_memory)
torch._dynamo.mark_static_address(new_layer_key_cache)
torch._dynamo.mark_static_address(new_layer_value_cache)
self.key_cache.append(new_layer_key_cache)
self.value_cache.append(new_layer_value_cache)
# Make sure to initialize the on-device layer if it does not already exist
if self.device_key_cache is None and not self.is_sliding[layer_idx]:
self.device_key_cache = []
self.device_value_cache = []
# We need 2 layers to avoid race conditions when prefetching the next one
for _ in range(2):
device_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=key_states.device)
device_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=key_states.device)
torch._dynamo.mark_static_address(device_layer_key_cache)
torch._dynamo.mark_static_address(device_layer_value_cache)
self.device_key_cache.append(device_layer_key_cache)
self.device_value_cache.append(device_layer_value_cache)
def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
# Wait for prefetch stream if needed
if self._prefetch_stream is not None:
torch.cuda.default_stream(key_states.device).wait_stream(self._prefetch_stream)
# Get correct on-device layer
k_out = self.device_key_cache[self.active_device_layer]
v_out = self.device_value_cache[self.active_device_layer]
# Let's prefetch the next layer as soon as possible
self._prefetch_next_layer(layer_idx)
# Copy to on-device layer
k_out[:, :, cache_position] = key_states
v_out[:, :, cache_position] = value_states
# Copy to offloaded device
self.key_cache[layer_idx][:, :, cache_position] = key_states.to(self.offload_device)
self.value_cache[layer_idx][:, :, cache_position] = value_states.to(self.offload_device)
return k_out, v_out
def _prefetch_next_layer(self, layer_idx: int) -> None:
"""Based on current layer_idx, prefetch next full layer to the device."""
# Switch the active layer
self.active_device_layer = 0 if self.active_device_layer == 1 else 1
# Find the next non-sliding layer
try:
next_layer = layer_idx + 1 + self.is_sliding[layer_idx + 1 :].index(False)
# If there is none, we are past the last full layer, so wrap around and prefetch the first one
except ValueError:
next_layer = self.is_sliding.index(False)
# Run the prefetch on the dedicated stream when available (the two on-device layers alternate above).
if self._prefetch_stream is not None:
with torch.cuda.stream(self._prefetch_stream):
self._prefetch_layer_in_context(next_layer)
else:
self._prefetch_layer_in_context(next_layer)
def _prefetch_layer_in_context(self, layer_idx: int) -> None:
"""Performs the actual copy of the layer to device cache."""
if len(self.key_cache) > layer_idx:
self.device_key_cache[self.active_device_layer].copy_(self.key_cache[layer_idx], non_blocking=True)
self.device_value_cache[self.active_device_layer].copy_(self.value_cache[layer_idx], non_blocking=True)
# The layer was not yet initialized
else:
self.device_key_cache[self.active_device_layer].fill_(0.0)
self.device_value_cache[self.active_device_layer].fill_(0.0)
class MambaCache:
"""
Cache for mamba model which does not have attention mechanism and key value states.
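
To make the double-buffering in the new class explicit, here is a standalone sketch (not part of the diff) of the pattern that `_static_update` and `_prefetch_next_layer` implement together: two on-device buffers alternate roles, the compute stream waits for the side stream before reading, and the side stream waits for the compute stream before overwriting the buffer it is about to reuse. The helper name and the explicit second `wait_stream` are illustrative additions, and the sketch assumes a CUDA device.

import torch

def run_offloaded_layers(cpu_layers, compute_fn, device="cuda"):
    """Ping-pong prefetch: read layer i from one buffer while layer i+1 is copied into the other."""
    side = torch.cuda.Stream()
    bufs = [torch.empty_like(cpu_layers[0], device=device) for _ in range(2)]
    bufs[0].copy_(cpu_layers[0], non_blocking=True)  # copies are truly async only from pinned CPU memory
    active = 0
    for i in range(len(cpu_layers)):
        # The compute stream must see the finished copy of layer i before reading it.
        torch.cuda.default_stream().wait_stream(side)
        current, active = bufs[active], 1 - active
        with torch.cuda.stream(side):
            # Do not overwrite a buffer that previously queued compute may still be reading.
            side.wait_stream(torch.cuda.default_stream())
            bufs[active].copy_(cpu_layers[(i + 1) % len(cpu_layers)], non_blocking=True)
        compute_fn(current)  # queued on the default stream, overlaps with the prefetch above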

@@ -54,6 +54,7 @@ if is_torch_available():
HybridCache,
HybridChunkedCache,
MambaCache,
OffloadedHybridCache,
OffloadedStaticCache,
QuantizedCacheConfig,
QuantoQuantizedCache,
@@ -71,6 +72,8 @@ if is_torch_available():
"sliding_window": SlidingWindowCache,
"hybrid": HybridCache,
"hybrid_chunked": HybridChunkedCache,
"offloaded_hybrid": OffloadedHybridCache,
"offloaded_hybrid_chunked": OffloadedHybridCache,
"mamba": MambaCache,
}
QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache}
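
Both new strings point at the same class, so either spelling works when the cache is selected through a generation config. A small sketch, reusing the `model` and `inputs` from the earlier usage example:

from transformers import GenerationConfig

gen_config = GenerationConfig(
    max_new_tokens=64,
    cache_implementation="offloaded_hybrid_chunked",  # alias of "offloaded_hybrid"
)
out = model.generate(**inputs, generation_config=gen_config)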

@@ -35,6 +35,7 @@ from ..cache_utils import (
EncoderDecoderCache,
HybridChunkedCache,
OffloadedCache,
OffloadedHybridCache,
QuantizedCacheConfig,
StaticCache,
)
@@ -1834,7 +1835,9 @@ class GenerationMixin:
not hasattr(self, "_cache")
or (not isinstance(cache_to_check, cache_cls))
or cache_to_check.max_batch_size != batch_size
or isinstance(cache_to_check, HybridChunkedCache) # due to internal slicing, we always re-init
or isinstance(
cache_to_check, (HybridChunkedCache, OffloadedHybridCache)
) # due to internal slicing, we always re-init
)
if cache_implementation != "mamba":
need_new_cache = need_new_cache or cache_to_check.max_cache_len < max_cache_len
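
Because of the check above, `generate` always rebuilds these caches, so the string aliases are the intended path. Building one by hand and passing it as `past_key_values` remains possible; a hedged sketch, with illustrative argument values and reusing `model` and `inputs` from the first example:

import torch
from transformers.cache_utils import OffloadedHybridCache

cache = OffloadedHybridCache(
    config=model.config.get_text_config(),  # the text config carries the attention / sliding-window layout
    max_batch_size=1,
    max_cache_len=4096,
    device=model.device,
    dtype=torch.bfloat16,
)
out = model.generate(**inputs, past_key_values=cache, max_new_tokens=32)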