New cache tests and refactored Hybrid Cache (#37972)

Manuel de Prada Corral 2025-05-20 12:46:13 +02:00 committed by GitHub
parent 183fb3637c
commit d34e21e7dd
2 changed files with 471 additions and 147 deletions

View File

@@ -21,6 +21,104 @@ if is_hqq_available():
logger = logging.get_logger(__name__)

# Utility functions for static/sliding cache update logic
def _static_cache_update(
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    cache_position: Optional[torch.LongTensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Updates the static cache tensors in place.

    Args:
        k_cache (`torch.Tensor`): The key cache tensor to update.
        v_cache (`torch.Tensor`): The value cache tensor to update.
        key_states (`torch.Tensor`): The new key states to add.
        value_states (`torch.Tensor`): The new value states to add.
        cache_position (`Optional[torch.LongTensor]`): The position indices where the new states should be inserted.
            If None, the entire cache is overwritten (prefill).

    Returns:
        Tuple[`torch.Tensor`, `torch.Tensor`]: The updated key and value cache tensors (modified in-place).
    """
    if cache_position is None:
        # Prefill phase where seq_len potentially equals max_cache_len. Directly copy.
        k_cache.copy_(key_states)
        v_cache.copy_(value_states)
    else:
        # Generation phase. Update specific positions.
        # Use index_copy_ for in-place update (compile-friendly).
        try:
            k_cache.index_copy_(2, cache_position, key_states)
            v_cache.index_copy_(2, cache_position, value_states)
        except NotImplementedError:
            # Fallback for devices like MPS where index_copy_ might not be supported.
            k_cache[:, :, cache_position] = key_states
            v_cache[:, :, cache_position] = value_states
    return k_cache, v_cache
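
To see how this helper behaves, here is a minimal sketch (illustrative only, not part of the changeset) using dummy tensors in the `(batch, num_heads, seq_len, head_dim)` layout used throughout this file:

    # Sketch: prefill overwrites the whole buffer, a decode step writes one slot in place.
    import torch

    k_cache = torch.zeros(1, 1, 4, 1)  # (batch, heads, max_cache_len, head_dim)
    v_cache = torch.zeros(1, 1, 4, 1)

    # Prefill: cache_position=None copies the full key/value states into the cache.
    prefill = torch.tensor([1.0, 2.0, 3.0, 4.0])[None, None, :, None]
    _static_cache_update(k_cache, v_cache, prefill, prefill, cache_position=None)

    # Decode step: write a single new token at position 2, in place.
    new_token = torch.tensor([[[[9.0]]]])
    _static_cache_update(k_cache, v_cache, new_token, new_token, torch.tensor([2]))
    print(k_cache[0, 0, :, 0])  # tensor([1., 2., 9., 4.])
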
def _sliding_cache_update(
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    cache_position: torch.LongTensor,
    max_cache_len: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Updates the sliding window cache tensors, returning the potentially modified tensors.

    Args:
        k_cache (`torch.Tensor`): The key cache tensor to update.
        v_cache (`torch.Tensor`): The value cache tensor to update.
        key_states (`torch.Tensor`): The new key states to add.
        value_states (`torch.Tensor`): The new value states to add.
        cache_position (`torch.LongTensor`): The position indices where the new states should be inserted.
        max_cache_len (`int`): The maximum length of the sliding window cache.

    Returns:
        Tuple[`torch.Tensor`, `torch.Tensor`]: The key and value tensors representing the cache state after the update.
            For prefill > window, these are the full input states.
            Otherwise, they are the updated cache tensors.
    """
    # Handle prefill phase when prompt length > sliding_window_size
    if cache_position.shape[0] > max_cache_len:
        new_k = key_states[:, :, -max_cache_len:, :]
        new_v = value_states[:, :, -max_cache_len:, :]
        k_cache.copy_(new_k)
        v_cache.copy_(new_v)
        return key_states, value_states

    # Sliding window logic for generation phase or prefill < window
    slicing = torch.arange(max_cache_len, device=value_states.device)
    current_seq_len = cache_position[-1] + 1  # Use last position to determine current length
    to_shift = current_seq_len > max_cache_len
    indices = (slicing + to_shift.sum()) % max_cache_len

    k_out_shifted = k_cache[:, :, indices]
    v_out_shifted = v_cache[:, :, indices]

    # Clamp cache_position to determine the *target index* within the shifted cache view
    update_position = cache_position.clamp(min=0, max=max_cache_len - 1)

    try:
        k_out_updated = k_out_shifted.index_copy(2, update_position, key_states)
        v_out_updated = v_out_shifted.index_copy(2, update_position, value_states)
    except NotImplementedError:
        # Fallback for MPS: clone and modify the clone
        k_out_updated = k_out_shifted.clone()
        v_out_updated = v_out_shifted.clone()
        k_out_updated[:, :, update_position] = key_states
        v_out_updated[:, :, update_position] = value_states

    k_cache.copy_(k_out_updated)
    v_cache.copy_(v_out_updated)
    return k_out_updated, v_out_updated
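
Similarly, a small sketch (illustrative only) of the sliding path: once the last cache position passes the window, the buffer is rolled left by one and the new token lands in the final slot:

    # Sketch: a 4-slot sliding window already prefilled with positions 0..3.
    import torch

    k_cache = torch.tensor([1.0, 2.0, 3.0, 4.0])[None, None, :, None].clone()
    v_cache = k_cache.clone()

    # Token at absolute position 4: the window shifts and the new value takes the last slot.
    new_token = torch.tensor([[[[5.0]]]])
    k_out, _ = _sliding_cache_update(k_cache, v_cache, new_token, new_token, torch.tensor([4]), max_cache_len=4)
    print(k_out[0, 0, :, 0])    # tensor([2., 3., 4., 5.])
    print(k_cache[0, 0, :, 0])  # the cache tensor itself was updated in place to the same values

For prompts longer than the window, the helper instead copies only the last `max_cache_len` positions into the buffer but returns the full input states, so attention during prefill still sees the whole prompt.
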
class Cache:
    """
    Base, abstract class for all caches. The actual data structure is specific to each subclass.
@@ -1264,28 +1362,16 @@ class StaticCache(Cache):
        """
        if cache_kwargs is None:
            cache_kwargs = {}
-        cache_position = cache_kwargs.get("cache_position")
-        k_out = self.key_cache[layer_idx]
-        v_out = self.value_cache[layer_idx]
-        key_states = key_states.to(k_out.dtype)
-        value_states = value_states.to(v_out.dtype)
-        if cache_position is None:
-            k_out.copy_(key_states)
-            v_out.copy_(value_states)
-        else:
-            # Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to
-            # `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does explicitly an in-place
-            # operation, that avoids copies and uses less memory.
-            try:
-                k_out.index_copy_(2, cache_position, key_states)
-                v_out.index_copy_(2, cache_position, value_states)
-            except NotImplementedError:
-                # The operator 'aten::index_copy.out' is not currently implemented for the MPS device.
-                k_out[:, :, cache_position] = key_states
-                v_out[:, :, cache_position] = value_states
-        return k_out, v_out
+        key_states = key_states.to(self.key_cache[layer_idx].dtype)
+        value_states = value_states.to(self.value_cache[layer_idx].dtype)
+        return _static_cache_update(
+            self.key_cache[layer_idx],
+            self.value_cache[layer_idx],
+            key_states,
+            value_states,
+            cache_kwargs.get("cache_position"),
+        )

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states that were seen by the model."""
@@ -1314,7 +1400,7 @@ class SlidingWindowCache(StaticCache):
    The `to_shift` is only true once we are above sliding_window. Thus with `sliding_window==64`:

-   indices = (slicing + to_shift[-1].int()-1) % self.config.sliding_window
+   indices = (slicing + to_shift[-1].sum()-1) % self.config.sliding_window
    tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
            19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36,
            37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
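
The index arithmetic this docstring describes can be reproduced in isolation; here is a small sketch (not part of the diff) using a window of 8 instead of 64 for brevity:

    import torch

    sliding_window = 8
    slicing = torch.ones(sliding_window, dtype=torch.long).cumsum(0)  # tensor([1, 2, ..., 8])
    cache_position = torch.tensor([9])                                # one step past the window
    to_shift = cache_position > sliding_window - 1                    # tensor([True])
    indices = (slicing + to_shift[-1].sum() - 1) % sliding_window
    print(indices)  # tensor([1, 2, 3, 4, 5, 6, 7, 0]) -> rotate left by one; the oldest slot becomes the write target
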
@@ -1398,46 +1484,21 @@ class SlidingWindowCache(StaticCache):
        if cache_kwargs is None:
            cache_kwargs = {}
        cache_position = cache_kwargs.get("cache_position")
-        k_out = self.key_cache[layer_idx]
-        v_out = self.value_cache[layer_idx]
-        key_states = key_states.to(k_out.dtype)
-        value_states = value_states.to(v_out.dtype)
-        # assume this only happens in prefill phase when prompt length > sliding_window_size (= max_cache_len)
-        if cache_position.shape[0] >= self.max_cache_len:
-            k_out = key_states[:, :, -self.max_cache_len :, :]
-            v_out = value_states[:, :, -self.max_cache_len :, :]
-            # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
-            self.key_cache[layer_idx] += k_out
-            self.value_cache[layer_idx] += v_out
-            # we should return the whole states instead of k_out, v_out to take the whole prompt
-            # into consideration when building kv cache instead of just throwing away tokens outside of the window
-            return key_states, value_states
-        slicing = torch.ones(self.max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
-        to_shift = cache_position > self.max_cache_len - 1
-        cache_position = cache_position.clamp(0, self.max_cache_len - 1)
-        indices = (slicing + to_shift[-1].int() - 1) % self.max_cache_len
-        k_out = k_out[:, :, indices]
-        v_out = v_out[:, :, indices]
-        try:
-            k_out.index_copy_(2, cache_position, key_states)
-            v_out.index_copy_(2, cache_position, value_states)
-        except NotImplementedError:
-            # The operator 'aten::index_copy.out' is not currently implemented for the MPS device.
-            k_out[:, :, cache_position] = key_states
-            v_out[:, :, cache_position] = value_states
-        # `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment)
-        self.key_cache[layer_idx].zero_()
-        self.value_cache[layer_idx].zero_()
-        self.key_cache[layer_idx] += k_out
-        self.value_cache[layer_idx] += v_out
-        return k_out, v_out
+        if cache_position is None:
+            raise ValueError("`cache_position` must be provided for SlidingWindowCache.")
+
+        key_states = key_states.to(self.key_cache[layer_idx].dtype)
+        value_states = value_states.to(self.value_cache[layer_idx].dtype)
+
+        return _sliding_cache_update(
+            self.key_cache[layer_idx],
+            self.value_cache[layer_idx],
+            key_states,
+            value_states,
+            cache_position,
+            self.max_cache_len,
+        )

    def get_max_cache_shape(self) -> Optional[int]:
        return self.max_cache_len
@@ -1680,12 +1741,13 @@ class HybridCache(Cache):
        super().__init__()
        if not hasattr(config, "sliding_window") or config.sliding_window is None:
            raise ValueError(
-                "Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
+                "Setting `cache_implementation` to 'hybrid' requires the model config supporting "
                "sliding window attention, please check if there is a `sliding_window` field in the model "
                "config and it's not set to None."
            )
-        self.max_cache_len = max_cache_len
-        self._sliding_window_max_len = min(config.sliding_window, max_cache_len)
+        self.max_cache_len = max_cache_len if max_cache_len is not None else config.max_position_embeddings
+        # Sliding layers can't be larger than the overall max cache len
+        self.sliding_window_len = min(config.sliding_window, self.max_cache_len)
        self.max_batch_size = max_batch_size
        # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
        self.head_dim = (
@@ -1694,22 +1756,17 @@ class HybridCache(Cache):
        self._dtype = dtype
        self.num_key_value_heads = (
-            config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
+            config.num_attention_heads
+            if getattr(config, "num_key_value_heads", None) is None
+            else config.num_key_value_heads
        )
        layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2  # 2 is for BC
-        self.is_sliding = torch.tensor(
-            [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], dtype=torch.bool
-        )
+        self.is_sliding_list = [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)]
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
-        global_cache_shape = (self.max_batch_size, self.num_key_value_heads, max_cache_len, self.head_dim)
-        sliding_cache_shape = (
-            self.max_batch_size,
-            self.num_key_value_heads,
-            self._sliding_window_max_len,
-            self.head_dim,
-        )
+        global_cache_shape = (self.max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
+        sliding_cache_shape = (self.max_batch_size, self.num_key_value_heads, self.sliding_window_len, self.head_dim)
        device = torch.device(device) if device is not None else None
        for i in range(config.num_hidden_layers):
            if layer_device_map is not None:
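
As an aside, a tiny sketch (not part of the diff) of how `layer_switch` determines which layers get the sliding cache, for the default pattern of 2 and a hypothetical 6-layer model:

    # Sketch: every `layer_switch`-th layer gets the full-length (global) cache, the rest are sliding.
    layer_switch = 2
    num_hidden_layers = 6
    is_sliding_list = [bool((i + 1) % layer_switch) for i in range(num_hidden_layers)]
    print(is_sliding_list)  # [True, False, True, False, True, False]

With `sliding_window_pattern=1`, `(i + 1) % 1` is always 0, so every layer is treated as global; the new `SyntheticCacheTest.test_hybrid_cache_static_mode` test below relies on exactly this to force a static layer.
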
@@ -1718,7 +1775,7 @@ class HybridCache(Cache):
                layer_device = device
            # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
            # breaks when updating the cache.
-            cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape
+            cache_shape = sliding_cache_shape if self.is_sliding_list[i] else global_cache_shape
            new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
            new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
            torch._dynamo.mark_static_address(new_layer_key_cache)
@@ -1726,42 +1783,6 @@ class HybridCache(Cache):
            self.key_cache.append(new_layer_key_cache)
            self.value_cache.append(new_layer_value_cache)

-    def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
-        if cache_position.shape[0] >= max_cache_len:
-            k_out = key_states[:, :, -max_cache_len:, :]
-            v_out = value_states[:, :, -max_cache_len:, :]
-            # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
-            self.key_cache[layer_idx] += k_out
-            self.value_cache[layer_idx] += v_out
-            # we should return the whole states instead of k_out, v_out to take the whole prompt
-            # into consideration when building kv cache instead of just throwing away tokens outside of the window
-            return key_states, value_states
-        slicing = torch.ones(max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
-        to_shift = cache_position > max_cache_len - 1
-        cache_position = cache_position.clamp(0, max_cache_len - 1)
-        indices = (slicing + to_shift[-1].int() - 1) % max_cache_len
-        k_out = k_out[:, :, indices]
-        v_out = v_out[:, :, indices]
-        k_out[:, :, cache_position] = key_states
-        v_out[:, :, cache_position] = value_states
-        # `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment)
-        self.key_cache[layer_idx].zero_()
-        self.value_cache[layer_idx].zero_()
-        self.key_cache[layer_idx] += k_out
-        self.value_cache[layer_idx] += v_out
-        return k_out, v_out
-
-    def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
-        k_out[:, :, cache_position] = key_states
-        v_out[:, :, cache_position] = value_states
-        self.key_cache[layer_idx] = k_out
-        self.value_cache[layer_idx] = v_out
-        return k_out, v_out

    def update(
        self,
        key_states: torch.Tensor,
@@ -1772,7 +1793,10 @@ class HybridCache(Cache):
        if cache_kwargs is None:
            cache_kwargs = {}
        cache_position = cache_kwargs.get("cache_position")
-        sliding_window = cache_kwargs.get("sliding_window")
+        if cache_position is None:
+            raise ValueError("`cache_position` must be provided for HybridCache.")
+        is_sliding_layer = self.is_sliding_list[layer_idx]

        # These two `if` blocks are only reached in multigpu and if `layer_device_map` is not passed. They are used
        # when the cache is initialized in the forward pass (e.g. Gemma2)
@@ -1781,25 +1805,22 @@ class HybridCache(Cache):
        if self.value_cache[layer_idx].device != value_states.device:
            self.value_cache[layer_idx] = self.value_cache[layer_idx].to(value_states.device)

-        k_out = self.key_cache[layer_idx]
-        v_out = self.value_cache[layer_idx]
-        key_states = key_states.to(k_out.dtype)
-        value_states = value_states.to(v_out.dtype)
+        k_cache = self.key_cache[layer_idx]
+        v_cache = self.value_cache[layer_idx]
+        key_states = key_states.to(k_cache.dtype)
+        value_states = value_states.to(v_cache.dtype)

-        if sliding_window:
-            update_fn = self._sliding_update
+        if is_sliding_layer:
+            return _sliding_cache_update(
+                k_cache,
+                v_cache,
+                key_states,
+                value_states,
+                cache_position,
+                k_cache.shape[2],  # Use actual cache dim as max cache len
+            )
        else:
-            update_fn = self._static_update
-
-        return update_fn(
-            cache_position,
-            layer_idx,
-            key_states,
-            value_states,
-            k_out,
-            v_out,
-            k_out.shape[2],
-        )
+            return _static_cache_update(k_cache, v_cache, key_states, value_states, cache_position)

    def get_max_cache_shape(self) -> Optional[int]:
        return self.max_cache_len
@@ -2033,7 +2054,7 @@ class OffloadedHybridCache(HybridChunkedCache):
        # TODO (joao): to enable this cache on multiple devices use the pattern from `OffloadedCache`, which keeps
        # track of the original device of each layer
-        unique_devices = set(layer_device_map.values())
+        unique_devices = set(layer_device_map.values()) if layer_device_map else set()
        if len(unique_devices) > 1:
            raise ValueError(f"OffloadedHybridCache does not support multiple devices. Got devices: {unique_devices}")
@@ -2292,7 +2313,7 @@ class OffloadedStaticCache(StaticCache):
        # TODO (joao): to enable this cache on multiple devices use the pattern from `OffloadedCache`, which keeps
        # track of the original device of each layer
-        unique_devices = set(layer_device_map.values())
+        unique_devices = set(layer_device_map.values()) if layer_device_map else set()
        if len(unique_devices) > 1:
            raise ValueError(f"OffloadedStaticCache does not support multiple devices. Got devices: {unique_devices}")
@@ -2369,6 +2390,9 @@ class OffloadedStaticCache(StaticCache):
            A tuple containing the updated key and value states.
        """
+        key_states = key_states.to(self.key_cache[layer_idx].dtype)
+        value_states = value_states.to(self.value_cache[layer_idx].dtype)
+
        if layer_idx == 0:
            # Update seen tokens.
            # TODO(gante): Remove this.

View File

@@ -46,10 +46,14 @@ if is_torch_available():
        Cache,
        ClvpForCausalLM,
        DynamicCache,
+        Gemma2Config,
        GenerationConfig,
+        HybridCache,
        LlamaConfig,
+        SlidingWindowCache,
        StaticCache,
        convert_and_export_with_cache,
+        pipeline,
    )
@@ -188,6 +192,21 @@ class CacheTest(unittest.TestCase):
        self.assertTrue(cached_values.shape == (1, 1, 10, 128))

def _skip_on_failed_cache_prerequisites(test, cache_implementation):
    """Function to skip tests on failed cache prerequisites, given a cache implementation"""
    # Installed dependencies
    if cache_implementation == "quantized" and not is_optimum_quanto_available():
        test.skipTest("Quanto is not available")
    # Devices
    if "offloaded" in cache_implementation:
        has_accelerator = torch_device is not None and torch_device != "cpu"
        if not has_accelerator:
            test.skipTest("Offloaded caches require an accelerator")
        if cache_implementation in ["offloaded_static", "offloaded_hybrid_chunked"]:
            if backend_device_count(torch_device) != 1:
                test.skipTest("Offloaded static caches require exactly 1 accelerator")
class CacheIntegrationTest(unittest.TestCase):
    """Fast cache integration tests that share the same small model"""

@@ -200,24 +219,10 @@ class CacheIntegrationTest(unittest.TestCase):
        )
        cls.model.config.sliding_window = 256  # hack to enable the use of caches with sliding windows
-    def _skip_on_failed_cache_prerequisites(self, cache_implementation):
-        """Function to skip tests on failed cache prerequisites, given a cache implementation"""
-        # Installed dependencies
-        if cache_implementation == "quantized" and not is_optimum_quanto_available():
-            self.skipTest("Quanto is not available")
-        # Devices
-        if "offloaded" in cache_implementation:
-            has_accelerator = torch_device is not None and torch_device != "cpu"
-            if not has_accelerator:
-                self.skipTest("Offloaded caches require an accelerator")
-            if cache_implementation in ["offloaded_static", "offloaded_hybrid_chunked"]:
-                if backend_device_count(torch_device) != 1:
-                    self.skipTest("Offloaded static caches require exactly 1 accelerator")

    @parameterized.expand(TEST_CACHE_IMPLEMENTATIONS)
    def test_cache_batched(self, cache_implementation):
        """Sanity check: caches' `.update` function expects batched inputs"""
-        self._skip_on_failed_cache_prerequisites(cache_implementation)
+        _skip_on_failed_cache_prerequisites(self, cache_implementation)

        EXPECTED_GENERATION = ["A sequence: 1, 2, 3, 4, 5, 6, 7, 8,", "A sequence: A, B, C, D, E, F, G, H"]
@@ -246,7 +251,7 @@ class CacheIntegrationTest(unittest.TestCase):
        Sanity check: caches' `reorder_cache` is operational. We can confirm this by looking at the beam indices
        (an output sequence contains multiple beam indices).
        """
-        self._skip_on_failed_cache_prerequisites(cache_implementation)
+        _skip_on_failed_cache_prerequisites(self, cache_implementation)
        if cache_implementation == "offloaded_hybrid_chunked":
            # TODO (joao, cyril): something is off with `offloaded_hybrid_chunked` aka `OffloadedHybridCache`: the
            # output sequence (and the corresponding beam scores, if we add `output_scores=True`) are significantly
@@ -280,7 +285,7 @@ class CacheIntegrationTest(unittest.TestCase):
    @parameterized.expand(TEST_CACHE_IMPLEMENTATIONS)
    def test_cache_extra_left_padding(self, cache_implementation):
        """Tests that adding extra left-padding does not affect the generation with the cache"""
-        self._skip_on_failed_cache_prerequisites(cache_implementation)
+        _skip_on_failed_cache_prerequisites(self, cache_implementation)

        EXPECTED_GENERATION = ["The cat's whiskers are also a sign of anxiety."]
@@ -552,6 +557,28 @@ class CacheHardIntegrationTest(unittest.TestCase):
        _ = model(**inputs)
        _ = model.generate(**inputs, max_new_tokens=2, cache_implementation="hybrid")

    @require_torch_gpu
    @parameterized.expand(TEST_CACHE_IMPLEMENTATIONS)
    def test_cache_gptj_model(self, cache_implementation):
        """Tests caches with GPT-J model. Regression test for https://github.com/huggingface/transformers/pull/34799"""
        _skip_on_failed_cache_prerequisites(self, cache_implementation)

        model_id = "hf-internal-testing/tiny-random-GPTJForCausalLM"
        pipe = pipeline("text-generation", model=model_id, torch_dtype=torch.bfloat16)
        pipe.model.config.sliding_window = (
            256 if cache_implementation in ["sliding_window", "hybrid", "hybrid_chunked"] else None
        )
        out = pipe(
            "hello world",
            cache_implementation=cache_implementation,
            max_new_tokens=10,
            do_sample=False,
            disable_compile=True,
            return_tensors=True,
        )[0]["generated_token_ids"][-10:]
        EXPECTED_OUTPUT = [879, 175, 39, 141, 1000, 975, 951, 991, 683, 441]
        self.assertListEqual(out, EXPECTED_OUTPUT)
@require_torch
class CacheExportIntegrationTest(unittest.TestCase):

@@ -721,3 +748,276 @@ class CacheExportIntegrationTest(unittest.TestCase):
            dynamic_shapes=dynamic_shapes,
            strict=False,
        )
class SyntheticCacheTest(unittest.TestCase):
    """Tests cache behavior with simple dummy data."""

    def setUp(self):
        """Set up common configuration and cache instances for all tests."""
        self.window_size = 4
        self.max_cache_len = 4
        self.config = Gemma2Config(
            num_hidden_layers=1,
            num_key_value_heads=1,
            num_attention_heads=1,
            head_dim=1,
            hidden_size=1,
            sliding_window=self.window_size,
            sliding_window_pattern=2,  # Default pattern for hybrid sliding
        )
    def test_static_cache_out_of_bounds(self):
        """Test StaticCache raises IndexError for out-of-bounds positions."""
        static_cache = StaticCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
        pos_out_of_bounds = torch.tensor([self.max_cache_len])  # Position >= max_cache_len
        with self.assertRaises(IndexError):
            static_cache.update(
                key_states=torch.tensor([[[[1.0]]]]),
                value_states=torch.tensor([[[[1.0]]]]),
                layer_idx=0,
                cache_kwargs={"cache_position": pos_out_of_bounds},
            )
    def test_static_cache(self):
        """Test StaticCache with manually prefilled states and hardcoded assertions.
        Scenario 1: Fill up to near capacity
            prefill:      [1.0, 2.0, 0.0, 0.0]
            update pos 2: [1.0, 2.0, 3.0, 0.0]
        Scenario 2: Fill to capacity
            update pos 3: [1.0, 2.0, 3.0, 4.0]
        """
        # Scenario 1: Fill up to near capacity
        static_cache = StaticCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
        prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None]
        static_cache.update(key_states=prefill, value_states=prefill, layer_idx=0, cache_kwargs=None)
        static_cache.update(
            key_states=torch.tensor(3.0)[None, None, None, None],
            value_states=torch.tensor(3.0)[None, None, None, None],
            layer_idx=0,
            cache_kwargs={"cache_position": torch.tensor([2])},
        )
        self.assertEqual(
            static_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "StaticCache Scenario 1 failed"
        )

        # Scenario 2: Fill to capacity
        static_cache.update(
            key_states=torch.tensor(4.0)[None, None, None, None],
            value_states=torch.tensor(4.0)[None, None, None, None],
            layer_idx=0,
            cache_kwargs={"cache_position": torch.tensor([3])},
        )
        self.assertEqual(
            static_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "StaticCache Scenario 2 failed"
        )
    def test_sliding_window_cache(self):
        """Test SlidingWindowCache with manually prefilled states and hardcoded assertions.
        Scenario 1: Update within window, no slide yet
            prefill:      [1.0, 2.0, 0.0, 0.0]
            update pos 2: [1.0, 2.0, 3.0, 0.0]
        Scenario 2: Update causing slide
            prefill:      [1.0, 2.0, 3.0, 4.0]
            update pos 4: [2.0, 3.0, 4.0, 5.0] (shift happens as pos > window_size-1)
        Scenario 3: Long prompt handling (prompt_len > window_size)
            input:  [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
            result: [3.0, 4.0, 5.0, 6.0] (keeps last window_size tokens)
        """
        # Scenario 1: Update within window, no slide yet
        sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
        prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None]
        sliding_cache.update(
            key_states=prefill,
            value_states=prefill,
            layer_idx=0,
            cache_kwargs={"cache_position": torch.arange(4), "sliding_window": self.window_size},
        )
        sliding_cache.update(
            key_states=torch.tensor(3.0)[None, None, None, None],
            value_states=torch.tensor(3.0)[None, None, None, None],
            layer_idx=0,
            cache_kwargs={"cache_position": torch.tensor([2]), "sliding_window": self.window_size},
        )
        self.assertEqual(
            sliding_cache.key_cache[0][0, 0, :, 0].tolist(),
            [1.0, 2.0, 3.0, 0.0],
            "SlidingWindowCache Scenario 1 failed",
        )

        # Scenario 2: Update causing slide
        sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
        prefill = torch.tensor([1.0, 2.0, 3.0, 4.0])[None, None, :, None]
        sliding_cache.update(
            key_states=prefill,
            value_states=prefill,
            layer_idx=0,
            cache_kwargs={"cache_position": torch.arange(4), "sliding_window": self.window_size},
        )
        sliding_cache.update(
            key_states=torch.tensor(5.0)[None, None, None, None],
            value_states=torch.tensor(5.0)[None, None, None, None],
            layer_idx=0,
            cache_kwargs={"cache_position": torch.tensor([4]), "sliding_window": self.window_size},
        )
        self.assertEqual(
            sliding_cache.key_cache[0][0, 0, :, 0].tolist(),
            [2.0, 3.0, 4.0, 5.0],
            "SlidingWindowCache Scenario 2 failed",
        )

        # Scenario 3: Long prompt handling
        sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
        long_prefill = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])[None, None, :, None]
        sliding_cache.update(
            key_states=long_prefill,
            value_states=long_prefill,
            layer_idx=0,
            cache_kwargs={"cache_position": torch.arange(6), "sliding_window": self.window_size},
        )
        self.assertEqual(
            sliding_cache.key_cache[0][0, 0, :, 0].tolist(),
            [3.0, 4.0, 5.0, 6.0],
            "SlidingWindowCache Scenario 3 failed",
        )
    def test_hybrid_cache_static_mode(self):
        """Test HybridCache in static mode with hardcoded assertions.
        Scenario 1: Static layer behavior
            prefill:      [1.0, 2.0, 0.0, 0.0]
            update pos 2: [1.0, 2.0, 3.0, 0.0]
        Scenario 2: Fill to capacity
            update pos 3: [1.0, 2.0, 3.0, 4.0]
        """
        config = copy.deepcopy(self.config)
        config.sliding_window_pattern = 1  # Layer 0 is static (1 % 1 == 0)

        # Scenario 1
        hybrid_cache_static_mode = HybridCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len)
        prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None]
        hybrid_cache_static_mode.update(
            key_states=prefill,
            value_states=prefill,
            layer_idx=0,
            cache_kwargs={"cache_position": torch.arange(4)},
        )
        hybrid_cache_static_mode.update(
            key_states=torch.tensor(3.0)[None, None, None, None],
            value_states=torch.tensor(3.0)[None, None, None, None],
            layer_idx=0,
            cache_kwargs={"cache_position": torch.tensor([2])},
        )
        self.assertEqual(
            hybrid_cache_static_mode.key_cache[0][0, 0, :, 0].tolist(),
            [1.0, 2.0, 3.0, 0.0],
            "HybridCache Static Scenario 1 failed",
        )

        # Scenario 2
        hybrid_cache_static_mode.update(
            key_states=torch.tensor(4.0)[None, None, None, None],
            value_states=torch.tensor(4.0)[None, None, None, None],
            layer_idx=0,
            cache_kwargs={"cache_position": torch.tensor([3])},
        )
        self.assertEqual(
            hybrid_cache_static_mode.key_cache[0][0, 0, :, 0].tolist(),
            [1.0, 2.0, 3.0, 4.0],
            "HybridCache Static Scenario 2 failed",
        )
    def test_hybrid_cache_sliding_mode(self):
        """Test HybridCache in sliding mode with hardcoded assertions.
        Scenario 1: Update within window, no slide yet
            prefill:      [1.0, 2.0, 0.0, 0.0]
            update pos 2: [1.0, 2.0, 3.0, 0.0]
        Scenario 2: Update causing first slide
            prefill:      [1.0, 2.0, 3.0, 4.0]
            update pos 4: [2.0, 3.0, 4.0, 5.0] (shift happens as pos > window_size-1)
        Scenario 3: Update causing subsequent slide
            update pos 5: [3.0, 4.0, 5.0, 6.0] (shift continues)
        Scenario 4: Long prompt handling (prompt_len > window_size)
            input:  [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
            result: [3.0, 4.0, 5.0, 6.0] (keeps last window_size tokens)
        """
        # Scenario 1: Update within window, no slide yet
        hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
        prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None]
        hybrid_cache.update(
            key_states=prefill,
            value_states=prefill,
            layer_idx=0,
            cache_kwargs={"cache_position": torch.arange(4), "sliding_window": self.window_size},
        )
        hybrid_cache.update(
            key_states=torch.tensor(3.0)[None, None, None, None],
            value_states=torch.tensor(3.0)[None, None, None, None],
            layer_idx=0,
            cache_kwargs={"cache_position": torch.tensor([2]), "sliding_window": self.window_size},
        )
        self.assertEqual(
            hybrid_cache.key_cache[0][0, 0, :, 0].tolist(),
            [1.0, 2.0, 3.0, 0.0],
            "HybridCache Sliding Scenario 1 failed",
        )

        # Scenario 2: Update causing first slide
        hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
        prefill = torch.tensor([1.0, 2.0, 3.0, 4.0])[None, None, :, None]
        hybrid_cache.update(
            key_states=prefill,
            value_states=prefill,
            layer_idx=0,
            cache_kwargs={"cache_position": torch.arange(4), "sliding_window": self.window_size},
        )
        hybrid_cache.update(
            key_states=torch.tensor(5.0)[None, None, None, None],
            value_states=torch.tensor(5.0)[None, None, None, None],
            layer_idx=0,
            cache_kwargs={"cache_position": torch.tensor([4]), "sliding_window": self.window_size},
        )
        self.assertEqual(
            hybrid_cache.key_cache[0][0, 0, :, 0].tolist(),
            [2.0, 3.0, 4.0, 5.0],
            "HybridCache Sliding Scenario 2 failed",
        )

        # Scenario 3: Update causing subsequent slide
        hybrid_cache.update(
            key_states=torch.tensor(6.0)[None, None, None, None],
            value_states=torch.tensor(6.0)[None, None, None, None],
            layer_idx=0,
            cache_kwargs={"cache_position": torch.tensor([5]), "sliding_window": self.window_size},
        )
        self.assertEqual(
            hybrid_cache.key_cache[0][0, 0, :, 0].tolist(),
            [3.0, 4.0, 5.0, 6.0],
            "HybridCache Sliding Scenario 3 failed",
        )

        # Scenario 4: Long prompt handling
        hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
        long_prefill = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])[None, None, :, None]
        hybrid_cache.update(
            key_states=long_prefill,
            value_states=long_prefill,
            layer_idx=0,
            cache_kwargs={"cache_position": torch.arange(6), "sliding_window": self.window_size},
        )
        self.assertEqual(
            hybrid_cache.key_cache[0][0, 0, :, 0].tolist(),
            [3.0, 4.0, 5.0, 6.0],
            "HybridCache Sliding Scenario 4 failed",
        )