Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Offloaded hybrid cache for Llama4 (#37401)
* first try (maybe race condition)
* Update cache_utils.py
* cannot avoid the race condition -> use 2 layers
* Update cache_utils.py
* Update cache_utils.py
This commit is contained in:
parent 6d8b0b3378
commit fbb2054ed5
@@ -2011,6 +2011,118 @@ class HybridChunkedCache(Cache):
         self.cumulative_length = [0 for _ in range(len(self.cumulative_length))]
 
 
+class OffloadedHybridCache(HybridChunkedCache):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        max_batch_size: int,
+        max_cache_len: Optional[int] = None,
+        device: Union[torch.device, str, None] = None,
+        dtype: torch.dtype = torch.bfloat16,
+        offload_device: Union[str, torch.device] = torch.device("cpu"),
+        layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
+    ):
+        super().__init__(config, max_batch_size, max_cache_len, device, dtype, layer_device_map)
+        self.offload_device = torch.device(offload_device)
+        # Create a new CUDA stream for parallel prefetching.
+        self._prefetch_stream = torch.cuda.Stream() if torch._C._get_accelerator().type == "cuda" else None
+        # These will be dynamically created as the other layers are (for TP)
+        self.device_key_cache = None
+        self.device_value_cache = None
+        # Index of the on-device full layer to use (we need 2 to avoid race conditions when prefetching)
+        self.active_device_layer = 0
+
+    def initialise_cache_layer(self, layer_idx, key_states):
+        """Overridden to use the correct device if the layer is offloaded (and pin memory)."""
+        if len(self.key_cache) > layer_idx:
+            return
+
+        num_key_value_heads = key_states.shape[1]
+        device = key_states.device if self.is_sliding[layer_idx] else self.offload_device
+        pin_memory = not self.is_sliding[layer_idx]
+        global_cache_shape = (self.max_batch_size, num_key_value_heads, self.max_cache_len, self.head_dim)
+        sliding_cache_shape = (
+            self.max_batch_size,
+            num_key_value_heads,
+            self.sliding_window,
+            self.head_dim,
+        )
+        # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph
+        # breaks when updating the cache.
+        cache_shape = sliding_cache_shape if self.is_sliding[layer_idx] else global_cache_shape
+        new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device, pin_memory=pin_memory)
+        new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device, pin_memory=pin_memory)
+        torch._dynamo.mark_static_address(new_layer_key_cache)
+        torch._dynamo.mark_static_address(new_layer_value_cache)
+        self.key_cache.append(new_layer_key_cache)
+        self.value_cache.append(new_layer_value_cache)
+
+        # Make sure to initialize the on-device layers if they do not already exist
+        if self.device_key_cache is None and not self.is_sliding[layer_idx]:
+            self.device_key_cache = []
+            self.device_value_cache = []
+            # We need 2 layers to avoid race conditions when prefetching the next one
+            for _ in range(2):
+                device_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=key_states.device)
+                device_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=key_states.device)
+                torch._dynamo.mark_static_address(device_layer_key_cache)
+                torch._dynamo.mark_static_address(device_layer_value_cache)
+                self.device_key_cache.append(device_layer_key_cache)
+                self.device_value_cache.append(device_layer_value_cache)
+
+    def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
+        # Wait for the prefetch stream if needed
+        if self._prefetch_stream is not None:
+            torch.cuda.default_stream(key_states.device).wait_stream(self._prefetch_stream)
+
+        # Get the correct on-device layer
+        k_out = self.device_key_cache[self.active_device_layer]
+        v_out = self.device_value_cache[self.active_device_layer]
+
+        # Let's prefetch the next layer as soon as possible
+        self._prefetch_next_layer(layer_idx)
+
+        # Copy to the on-device layer
+        k_out[:, :, cache_position] = key_states
+        v_out[:, :, cache_position] = value_states
+
+        # Copy to the offload device
+        self.key_cache[layer_idx][:, :, cache_position] = key_states.to(self.offload_device)
+        self.value_cache[layer_idx][:, :, cache_position] = value_states.to(self.offload_device)
+
+        return k_out, v_out
+
+    def _prefetch_next_layer(self, layer_idx: int) -> None:
+        """Based on the current layer_idx, prefetch the next full layer to the device."""
+        # Switch the active layer
+        self.active_device_layer = 0 if self.active_device_layer == 1 else 1
+
+        # Find the next non-sliding layer
+        try:
+            next_layer = layer_idx + 1 + self.is_sliding[layer_idx + 1 :].index(False)
+        # In this case, we are at the last layer, and we go back to prefetch the first one
+        except ValueError:
+            next_layer = self.is_sliding.index(False)
+
+        # Alternate between the two on-device caches
+        if self._prefetch_stream is not None:
+            with torch.cuda.stream(self._prefetch_stream):
+                self._prefetch_layer_in_context(next_layer)
+        else:
+            self._prefetch_layer_in_context(next_layer)
+
+    def _prefetch_layer_in_context(self, layer_idx: int) -> None:
+        """Performs the actual copy of the layer to the device cache."""
+        if len(self.key_cache) > layer_idx:
+            self.device_key_cache[self.active_device_layer].copy_(self.key_cache[layer_idx], non_blocking=True)
+            self.device_value_cache[self.active_device_layer].copy_(self.value_cache[layer_idx], non_blocking=True)
+        # The layer was not yet initialized
+        else:
+            self.device_key_cache[self.active_device_layer].fill_(0.0)
+            self.device_value_cache[self.active_device_layer].fill_(0.0)
+
+
 class MambaCache:
     """
     Cache for mamba model which does not have attention mechanism and key value states.
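The class above is built around double buffering: the full-attention layers live on the offload device in pinned memory, and two on-device buffers alternate so that the copy of the next full layer can overlap with attention on the current one, synchronized through a dedicated CUDA stream. The following is a minimal standalone sketch of that pattern, not part of the diff; the tensor shapes, layer count, and helper names are illustrative only.

    import torch

    # Minimal sketch (illustrative names/shapes) of the double-buffered prefetch pattern.
    num_layers, shape = 4, (1, 8, 1024, 128)
    use_cuda = torch.cuda.is_available()
    # Offloaded layers stay on the host; pinned memory enables async host-to-device copies.
    offloaded = [torch.zeros(shape, pin_memory=use_cuda) for _ in range(num_layers)]

    if use_cuda:
        device = torch.device("cuda")
        prefetch_stream = torch.cuda.Stream()
        # Two on-device buffers: one is consumed while the other receives the prefetch.
        buffers = [torch.empty(shape, device=device) for _ in range(2)]
        active = 0

        def prefetch(layer_idx: int, slot: int) -> None:
            # Issue the copy on the side stream so it overlaps with compute on the default stream.
            with torch.cuda.stream(prefetch_stream):
                buffers[slot].copy_(offloaded[layer_idx], non_blocking=True)

        prefetch(0, active)
        for layer_idx in range(num_layers):
            # Block the default stream until the pending copy into the active buffer is done.
            torch.cuda.default_stream(device).wait_stream(prefetch_stream)
            keys = buffers[active]
            # Flip buffers and immediately start fetching the next layer into the free slot.
            active = 1 - active
            prefetch((layer_idx + 1) % num_layers, active)
            # ... attention for `layer_idx` would consume `keys` here ...
            print(layer_idx, keys.shape)

The real implementation additionally keeps sliding-window layers fully on the device and writes new key/value states both to the active on-device buffer and to the offloaded copy, as shown in the diff above.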
@@ -54,6 +54,7 @@ if is_torch_available():
         HybridCache,
         HybridChunkedCache,
         MambaCache,
+        OffloadedHybridCache,
         OffloadedStaticCache,
         QuantizedCacheConfig,
         QuantoQuantizedCache,
@@ -71,6 +72,8 @@ if is_torch_available():
         "sliding_window": SlidingWindowCache,
         "hybrid": HybridCache,
         "hybrid_chunked": HybridChunkedCache,
+        "offloaded_hybrid": OffloadedHybridCache,
+        "offloaded_hybrid_chunked": OffloadedHybridCache,
         "mamba": MambaCache,
     }
     QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache}
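With the two new mapping entries, the cache can be selected by name through the generation config. A hedged usage sketch follows; the checkpoint id and generation settings are placeholders and not part of this diff.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Hypothetical id: use a checkpoint whose config mixes full-attention and chunked/sliding layers.
    model_id = "your-org/llama4-style-checkpoint"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")

    inputs = tokenizer("Offloading the full-attention KV cache to CPU", return_tensors="pt").to(model.device)
    output = model.generate(
        **inputs,
        max_new_tokens=64,
        # "offloaded_hybrid" (or the "offloaded_hybrid_chunked" alias) resolves to
        # OffloadedHybridCache through the mapping added above.
        cache_implementation="offloaded_hybrid",
    )
    print(tokenizer.decode(output[0], skip_special_tokens=True))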
@@ -35,6 +35,7 @@ from ..cache_utils import (
     EncoderDecoderCache,
     HybridChunkedCache,
     OffloadedCache,
+    OffloadedHybridCache,
     QuantizedCacheConfig,
     StaticCache,
 )
@@ -1834,7 +1835,9 @@ class GenerationMixin:
             not hasattr(self, "_cache")
             or (not isinstance(cache_to_check, cache_cls))
             or cache_to_check.max_batch_size != batch_size
-            or isinstance(cache_to_check, HybridChunkedCache)  # due to internal slicing, we always re-init
+            or isinstance(
+                cache_to_check, (HybridChunkedCache, OffloadedHybridCache)
+            )  # due to internal slicing, we always re-init
         )
         if cache_implementation != "mamba":
             need_new_cache = need_new_cache or cache_to_check.max_cache_len < max_cache_len
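The re-init condition above means a HybridChunkedCache or OffloadedHybridCache held by the model is never reused across generate() calls. For completeness, here is a hedged sketch of constructing the cache directly and passing it in as past_key_values; the checkpoint id, the max_cache_len budget, and the choice of model.config (rather than a text sub-config) are assumptions, not taken from this diff.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from transformers.cache_utils import OffloadedHybridCache

    model_id = "your-org/llama4-style-checkpoint"  # hypothetical id
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")

    cache = OffloadedHybridCache(
        config=model.config,      # composite checkpoints may need the text sub-config instead
        max_batch_size=1,
        max_cache_len=8192,       # assumed budget for the full-attention layers
        device=model.device,
        dtype=torch.bfloat16,
        offload_device="cpu",     # full-attention layers are kept here in pinned memory
    )

    inputs = tokenizer("The capital of France is", return_tensors="pt").to(model.device)
    output = model.generate(**inputs, past_key_values=cache, max_new_tokens=32)
    print(tokenizer.decode(output[0], skip_special_tokens=True))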