From d34e21e7dd392c76e7852d836b6f30ba1a2c5d62 Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral <6536835+manueldeprada@users.noreply.github.com> Date: Tue, 20 May 2025 12:46:13 +0200 Subject: [PATCH] New cache tests and refactored Hybrid Cache (#37972) --- src/transformers/cache_utils.py | 284 ++++++++++++++------------- tests/utils/test_cache_utils.py | 334 ++++++++++++++++++++++++++++++-- 2 files changed, 471 insertions(+), 147 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 40737bb69ad..d24edd390c2 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -21,6 +21,104 @@ if is_hqq_available(): logger = logging.get_logger(__name__) +# Utility functions for static/sliding cache update logic +def _static_cache_update( + k_cache: torch.Tensor, + v_cache: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + cache_position: Optional[torch.LongTensor], +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the static cache tensors in place. + + Args: + k_cache (`torch.Tensor`): The key cache tensor to update. + v_cache (`torch.Tensor`): The value cache tensor to update. + key_states (`torch.Tensor`): The new key states to add. + value_states (`torch.Tensor`): The new value states to add. + cache_position (`Optional[torch.LongTensor]`): The position indices where the new states should be inserted. + If None, the entire cache is overwritten (prefill). + + Returns: + Tuple[`torch.Tensor`, `torch.Tensor`]: The updated key and value cache tensors (modified in-place). + """ + if cache_position is None: + # Prefill phase where seq_len potentially equals max_cache_len. Directly copy. + k_cache.copy_(key_states) + v_cache.copy_(value_states) + else: + # Generation phase. Update specific positions. + # Use index_copy_ for in-place update (compile-friendly). + try: + k_cache.index_copy_(2, cache_position, key_states) + v_cache.index_copy_(2, cache_position, value_states) + except NotImplementedError: + # Fallback for devices like MPS where index_copy_ might not be supported. + k_cache[:, :, cache_position] = key_states + v_cache[:, :, cache_position] = value_states + return k_cache, v_cache + + +def _sliding_cache_update( + k_cache: torch.Tensor, + v_cache: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + cache_position: torch.LongTensor, + max_cache_len: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the sliding window cache tensors, returning the potentially modified tensors. + + Args: + k_cache (`torch.Tensor`): The key cache tensor to update. + v_cache (`torch.Tensor`): The value cache tensor to update. + key_states (`torch.Tensor`): The new key states to add. + value_states (`torch.Tensor`): The new value states to add. + cache_position (`torch.LongTensor`): The position indices where the new states should be inserted. + max_cache_len (`int`): The maximum length of the sliding window cache. + + Returns: + Tuple[`torch.Tensor`, `torch.Tensor`]: The key and value tensors representing the cache state after the update. + For prefill > window, these are the full input states. + Otherwise, they are the updated cache tensors. 
+ """ + # Handle prefill phase when prompt length > sliding_window_size + if cache_position.shape[0] > max_cache_len: + new_k = key_states[:, :, -max_cache_len:, :] + new_v = value_states[:, :, -max_cache_len:, :] + k_cache.copy_(new_k) + v_cache.copy_(new_v) + return key_states, value_states + + # Sliding window logic for generation phase or prefill < window + slicing = torch.arange(max_cache_len, device=value_states.device) + current_seq_len = cache_position[-1] + 1 # Use last position to determine current length + to_shift = current_seq_len > max_cache_len + indices = (slicing + to_shift.sum()) % max_cache_len + + k_out_shifted = k_cache[:, :, indices] + v_out_shifted = v_cache[:, :, indices] + + # Clamp cache_position to determine the *target index* within the shifted cache view + update_position = cache_position.clamp(min=0, max=max_cache_len - 1) + + try: + k_out_updated = k_out_shifted.index_copy(2, update_position, key_states) + v_out_updated = v_out_shifted.index_copy(2, update_position, value_states) + except NotImplementedError: + # Fallback for MPS: clone and modify the clone + k_out_updated = k_out_shifted.clone() + v_out_updated = v_out_shifted.clone() + k_out_updated[:, :, update_position] = key_states + v_out_updated[:, :, update_position] = value_states + + k_cache.copy_(k_out_updated) + v_cache.copy_(v_out_updated) + return k_out_updated, v_out_updated + + class Cache: """ Base, abstract class for all caches. The actual data structure is specific to each subclass. @@ -1264,28 +1362,16 @@ class StaticCache(Cache): """ if cache_kwargs is None: cache_kwargs = {} - cache_position = cache_kwargs.get("cache_position") - k_out = self.key_cache[layer_idx] - v_out = self.value_cache[layer_idx] - key_states = key_states.to(k_out.dtype) - value_states = value_states.to(v_out.dtype) - if cache_position is None: - k_out.copy_(key_states) - v_out.copy_(value_states) - else: - # Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to - # `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does explicitly an in-place - # operation, that avoids copies and uses less memory. - try: - k_out.index_copy_(2, cache_position, key_states) - v_out.index_copy_(2, cache_position, value_states) - except NotImplementedError: - # The operator 'aten::index_copy.out' is not currently implemented for the MPS device. - k_out[:, :, cache_position] = key_states - v_out[:, :, cache_position] = value_states - - return k_out, v_out + key_states = key_states.to(self.key_cache[layer_idx].dtype) + value_states = value_states.to(self.value_cache[layer_idx].dtype) + return _static_cache_update( + self.key_cache[layer_idx], + self.value_cache[layer_idx], + key_states, + value_states, + cache_kwargs.get("cache_position"), + ) def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: """Returns the sequence length of the cached states that were seen by the model.""" @@ -1314,7 +1400,7 @@ class SlidingWindowCache(StaticCache): The `to_shift` is only true once we are above sliding_window. 
Thus with `sliding_window==64`: - indices = (slicing + to_shift[-1].int()-1) % self.config.sliding_window + indices = (slicing + to_shift[-1].sum()-1) % self.config.sliding_window tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, @@ -1398,46 +1484,21 @@ class SlidingWindowCache(StaticCache): if cache_kwargs is None: cache_kwargs = {} cache_position = cache_kwargs.get("cache_position") - k_out = self.key_cache[layer_idx] - v_out = self.value_cache[layer_idx] - key_states = key_states.to(k_out.dtype) - value_states = value_states.to(v_out.dtype) - # assume this only happens in prefill phase when prompt length > sliding_window_size (= max_cache_len) - if cache_position.shape[0] >= self.max_cache_len: - k_out = key_states[:, :, -self.max_cache_len :, :] - v_out = value_states[:, :, -self.max_cache_len :, :] - # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly - self.key_cache[layer_idx] += k_out - self.value_cache[layer_idx] += v_out - # we should return the whole states instead of k_out, v_out to take the whole prompt - # into consideration when building kv cache instead of just throwing away tokens outside of the window - return key_states, value_states + if cache_position is None: + raise ValueError("`cache_position` must be provided for SlidingWindowCache.") - slicing = torch.ones(self.max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0) - to_shift = cache_position > self.max_cache_len - 1 - cache_position = cache_position.clamp(0, self.max_cache_len - 1) - indices = (slicing + to_shift[-1].int() - 1) % self.max_cache_len + key_states = key_states.to(self.key_cache[layer_idx].dtype) + value_states = value_states.to(self.value_cache[layer_idx].dtype) - k_out = k_out[:, :, indices] - v_out = v_out[:, :, indices] - - try: - k_out.index_copy_(2, cache_position, key_states) - v_out.index_copy_(2, cache_position, value_states) - except NotImplementedError: - # The operator 'aten::index_copy.out' is not currently implemented for the MPS device. - k_out[:, :, cache_position] = key_states - v_out[:, :, cache_position] = value_states - - # `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment) - self.key_cache[layer_idx].zero_() - self.value_cache[layer_idx].zero_() - - self.key_cache[layer_idx] += k_out - self.value_cache[layer_idx] += v_out - - return k_out, v_out + return _sliding_cache_update( + self.key_cache[layer_idx], + self.value_cache[layer_idx], + key_states, + value_states, + cache_position, + self.max_cache_len, + ) def get_max_cache_shape(self) -> Optional[int]: return self.max_cache_len @@ -1680,12 +1741,13 @@ class HybridCache(Cache): super().__init__() if not hasattr(config, "sliding_window") or config.sliding_window is None: raise ValueError( - "Setting `cache_implementation` to 'sliding_window' requires the model config supporting " + "Setting `cache_implementation` to 'hybrid' requires the model config supporting " "sliding window attention, please check if there is a `sliding_window` field in the model " "config and it's not set to None." 
) - self.max_cache_len = max_cache_len - self._sliding_window_max_len = min(config.sliding_window, max_cache_len) + self.max_cache_len = max_cache_len if max_cache_len is not None else config.max_position_embeddings + # Sliding layers can't be larger than the overall max cache len + self.sliding_window_len = min(config.sliding_window, self.max_cache_len) self.max_batch_size = max_batch_size # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads self.head_dim = ( @@ -1694,22 +1756,17 @@ class HybridCache(Cache): self._dtype = dtype self.num_key_value_heads = ( - config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads + config.num_attention_heads + if getattr(config, "num_key_value_heads", None) is None + else config.num_key_value_heads ) layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2 # 2 is for BC - self.is_sliding = torch.tensor( - [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], dtype=torch.bool - ) + self.is_sliding_list = [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)] self.key_cache: List[torch.Tensor] = [] self.value_cache: List[torch.Tensor] = [] - global_cache_shape = (self.max_batch_size, self.num_key_value_heads, max_cache_len, self.head_dim) - sliding_cache_shape = ( - self.max_batch_size, - self.num_key_value_heads, - self._sliding_window_max_len, - self.head_dim, - ) + global_cache_shape = (self.max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim) + sliding_cache_shape = (self.max_batch_size, self.num_key_value_heads, self.sliding_window_len, self.head_dim) device = torch.device(device) if device is not None else None for i in range(config.num_hidden_layers): if layer_device_map is not None: @@ -1718,7 +1775,7 @@ class HybridCache(Cache): layer_device = device # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph # breaks when updating the cache. 
- cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape + cache_shape = sliding_cache_shape if self.is_sliding_list[i] else global_cache_shape new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device) new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device) torch._dynamo.mark_static_address(new_layer_key_cache) @@ -1726,42 +1783,6 @@ class HybridCache(Cache): self.key_cache.append(new_layer_key_cache) self.value_cache.append(new_layer_value_cache) - def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len): - if cache_position.shape[0] >= max_cache_len: - k_out = key_states[:, :, -max_cache_len:, :] - v_out = value_states[:, :, -max_cache_len:, :] - # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly - self.key_cache[layer_idx] += k_out - self.value_cache[layer_idx] += v_out - # we should return the whole states instead of k_out, v_out to take the whole prompt - # into consideration when building kv cache instead of just throwing away tokens outside of the window - return key_states, value_states - - slicing = torch.ones(max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0) - to_shift = cache_position > max_cache_len - 1 - cache_position = cache_position.clamp(0, max_cache_len - 1) - indices = (slicing + to_shift[-1].int() - 1) % max_cache_len - k_out = k_out[:, :, indices] - v_out = v_out[:, :, indices] - - k_out[:, :, cache_position] = key_states - v_out[:, :, cache_position] = value_states - # `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment) - self.key_cache[layer_idx].zero_() - self.value_cache[layer_idx].zero_() - - self.key_cache[layer_idx] += k_out - self.value_cache[layer_idx] += v_out - return k_out, v_out - - def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len): - k_out[:, :, cache_position] = key_states - v_out[:, :, cache_position] = value_states - - self.key_cache[layer_idx] = k_out - self.value_cache[layer_idx] = v_out - return k_out, v_out - def update( self, key_states: torch.Tensor, @@ -1772,7 +1793,10 @@ class HybridCache(Cache): if cache_kwargs is None: cache_kwargs = {} cache_position = cache_kwargs.get("cache_position") - sliding_window = cache_kwargs.get("sliding_window") + if cache_position is None: + raise ValueError("`cache_position` must be provided for HybridCache.") + + is_sliding_layer = self.is_sliding_list[layer_idx] # These two `if` blocks are only reached in multigpu and if `layer_device_map` is not passed. They are used # when the cache is initialized in the forward pass (e.g. 
Gemma2) @@ -1781,25 +1805,22 @@ class HybridCache(Cache): if self.value_cache[layer_idx].device != value_states.device: self.value_cache[layer_idx] = self.value_cache[layer_idx].to(value_states.device) - k_out = self.key_cache[layer_idx] - v_out = self.value_cache[layer_idx] - key_states = key_states.to(k_out.dtype) - value_states = value_states.to(v_out.dtype) + k_cache = self.key_cache[layer_idx] + v_cache = self.value_cache[layer_idx] + key_states = key_states.to(k_cache.dtype) + value_states = value_states.to(v_cache.dtype) - if sliding_window: - update_fn = self._sliding_update + if is_sliding_layer: + return _sliding_cache_update( + k_cache, + v_cache, + key_states, + value_states, + cache_position, + k_cache.shape[2], # Use actual cache dim as max cache len + ) else: - update_fn = self._static_update - - return update_fn( - cache_position, - layer_idx, - key_states, - value_states, - k_out, - v_out, - k_out.shape[2], - ) + return _static_cache_update(k_cache, v_cache, key_states, value_states, cache_position) def get_max_cache_shape(self) -> Optional[int]: return self.max_cache_len @@ -2033,7 +2054,7 @@ class OffloadedHybridCache(HybridChunkedCache): # TODO (joao): to enable this cache on multiple devicesuse the pattern from `OffloadedCache`, which keeps # track of the original device of each layer - unique_devices = set(layer_device_map.values()) + unique_devices = set(layer_device_map.values()) if layer_device_map else set() if len(unique_devices) > 1: raise ValueError(f"OffloadedHybridCache does not support multiple devices. Got devices: {unique_devices}") @@ -2292,7 +2313,7 @@ class OffloadedStaticCache(StaticCache): # TODO (joao): to enable this cache on multiple devicesuse the pattern from `OffloadedCache`, which keeps # track of the original device of each layer - unique_devices = set(layer_device_map.values()) + unique_devices = set(layer_device_map.values()) if layer_device_map else set() if len(unique_devices) > 1: raise ValueError(f"OffloadedStaticCache does not support multiple devices. Got devices: {unique_devices}") @@ -2369,6 +2390,9 @@ class OffloadedStaticCache(StaticCache): A tuple containing the updated key and value states. """ + key_states = key_states.to(self.key_cache[layer_idx].dtype) + value_states = value_states.to(self.value_cache[layer_idx].dtype) + if layer_idx == 0: # Update seen tokens. # TODO(gante): Remove this. 
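For reference, the sliding-window update factored out above can be exercised in isolation. Below is a minimal sketch (not part of the patch) that reproduces the rolling-index logic of the new `_sliding_cache_update` helper on dummy tensors; the toy shapes, values, and the `toy_sliding_update` name are illustrative assumptions, not code from the repository.

import torch

max_cache_len = 4  # toy sliding window of 4 slots (batch=1, heads=1, head_dim=1)
k_cache = torch.zeros(1, 1, max_cache_len, 1)
v_cache = torch.zeros(1, 1, max_cache_len, 1)

def toy_sliding_update(k_cache, v_cache, key_states, value_states, cache_position, max_cache_len):
    # Prefill longer than the window: keep only the last `max_cache_len` tokens in the cache,
    # but return the full states so the whole prompt is attended to on this forward pass.
    if cache_position.shape[0] > max_cache_len:
        k_cache.copy_(key_states[:, :, -max_cache_len:, :])
        v_cache.copy_(value_states[:, :, -max_cache_len:, :])
        return key_states, value_states
    # Otherwise roll the window left by one slot once it is full, then write the new
    # token(s) into the clamped target position (out-of-place index_copy, then copy back).
    slicing = torch.arange(max_cache_len, device=value_states.device)
    to_shift = (cache_position[-1] + 1) > max_cache_len
    indices = (slicing + to_shift.sum()) % max_cache_len
    update_position = cache_position.clamp(max=max_cache_len - 1)
    k_out = k_cache[:, :, indices].index_copy(2, update_position, key_states)
    v_out = v_cache[:, :, indices].index_copy(2, update_position, value_states)
    k_cache.copy_(k_out)
    v_cache.copy_(v_out)
    return k_out, v_out

# Prefill 4 tokens, then generate a 5th: the oldest token is evicted from the window.
prefill = torch.tensor([1.0, 2.0, 3.0, 4.0])[None, None, :, None]
toy_sliding_update(k_cache, v_cache, prefill, prefill, torch.arange(4), max_cache_len)
new_token = torch.tensor(5.0)[None, None, None, None]
toy_sliding_update(k_cache, v_cache, new_token, new_token, torch.tensor([4]), max_cache_len)
print(k_cache[0, 0, :, 0].tolist())  # expected: [2.0, 3.0, 4.0, 5.0]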
diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index 089b45c192b..ea56763a65f 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -46,10 +46,14 @@ if is_torch_available(): Cache, ClvpForCausalLM, DynamicCache, + Gemma2Config, GenerationConfig, + HybridCache, LlamaConfig, + SlidingWindowCache, StaticCache, convert_and_export_with_cache, + pipeline, ) @@ -188,6 +192,21 @@ class CacheTest(unittest.TestCase): self.assertTrue(cached_values.shape == (1, 1, 10, 128)) +def _skip_on_failed_cache_prerequisites(test, cache_implementation): + """Function to skip tests on failed cache prerequisites, given a cache implementation""" + # Installed dependencies + if cache_implementation == "quantized" and not is_optimum_quanto_available(): + test.skipTest("Quanto is not available") + # Devices + if "offloaded" in cache_implementation: + has_accelerator = torch_device is not None and torch_device != "cpu" + if not has_accelerator: + test.skipTest("Offloaded caches require an accelerator") + if cache_implementation in ["offloaded_static", "offloaded_hybrid_chunked"]: + if backend_device_count(torch_device) != 1: + test.skipTest("Offloaded static caches require exactly 1 accelerator") + + class CacheIntegrationTest(unittest.TestCase): """Fast cache integration tests that share the same small model""" @@ -200,24 +219,10 @@ class CacheIntegrationTest(unittest.TestCase): ) cls.model.config.sliding_window = 256 # hack to enable the use of caches with sliding windows - def _skip_on_failed_cache_prerequisites(self, cache_implementation): - """Function to skip tests on failed cache prerequisites, given a cache implementation""" - # Installed dependencies - if cache_implementation == "quantized" and not is_optimum_quanto_available(): - self.skipTest("Quanto is not available") - # Devices - if "offloaded" in cache_implementation: - has_accelerator = torch_device is not None and torch_device != "cpu" - if not has_accelerator: - self.skipTest("Offloaded caches require an accelerator") - if cache_implementation in ["offloaded_static", "offloaded_hybrid_chunked"]: - if backend_device_count(torch_device) != 1: - self.skipTest("Offloaded static caches require exactly 1 accelerator") - @parameterized.expand(TEST_CACHE_IMPLEMENTATIONS) def test_cache_batched(self, cache_implementation): """Sanity check: caches' `.update` function expects batched inputs""" - self._skip_on_failed_cache_prerequisites(cache_implementation) + _skip_on_failed_cache_prerequisites(self, cache_implementation) EXPECTED_GENERATION = ["A sequence: 1, 2, 3, 4, 5, 6, 7, 8,", "A sequence: A, B, C, D, E, F, G, H"] @@ -246,7 +251,7 @@ class CacheIntegrationTest(unittest.TestCase): Sanity check: caches' `reorder_cache` is operational. We can confirm this by looking at the beam indices (an output sequence contains multiple beam indices). 
""" - self._skip_on_failed_cache_prerequisites(cache_implementation) + _skip_on_failed_cache_prerequisites(self, cache_implementation) if cache_implementation == "offloaded_hybrid_chunked": # TODO (joao, cyril): something is off with `offloaded_hybrid_chunked` aka `OffloadedHybridCache`: the # output sequence (and the corresponding beam scores, if we add `output_scores=True`) are significantly @@ -280,7 +285,7 @@ class CacheIntegrationTest(unittest.TestCase): @parameterized.expand(TEST_CACHE_IMPLEMENTATIONS) def test_cache_extra_left_padding(self, cache_implementation): """Tests that adding extra left-padding does not affect the generation with the cache""" - self._skip_on_failed_cache_prerequisites(cache_implementation) + _skip_on_failed_cache_prerequisites(self, cache_implementation) EXPECTED_GENERATION = ["The cat's whiskers are also a sign of anxiety."] @@ -552,6 +557,28 @@ class CacheHardIntegrationTest(unittest.TestCase): _ = model(**inputs) _ = model.generate(**inputs, max_new_tokens=2, cache_implementation="hybrid") + @require_torch_gpu + @parameterized.expand(TEST_CACHE_IMPLEMENTATIONS) + def test_cache_gptj_model(self, cache_implementation): + """Tests caches with GPT-J model. Regression test for https://github.com/huggingface/transformers/pull/34799""" + _skip_on_failed_cache_prerequisites(self, cache_implementation) + + model_id = "hf-internal-testing/tiny-random-GPTJForCausalLM" + pipe = pipeline("text-generation", model=model_id, torch_dtype=torch.bfloat16) + pipe.model.config.sliding_window = ( + 256 if cache_implementation in ["sliding_window", "hybrid", "hybrid_chunked"] else None + ) + out = pipe( + "hello world", + cache_implementation=cache_implementation, + max_new_tokens=10, + do_sample=False, + disable_compile=True, + return_tensors=True, + )[0]["generated_token_ids"][-10:] + EXPECTED_OUTPUT = [879, 175, 39, 141, 1000, 975, 951, 991, 683, 441] + self.assertListEqual(out, EXPECTED_OUTPUT) + @require_torch class CacheExportIntegrationTest(unittest.TestCase): @@ -721,3 +748,276 @@ class CacheExportIntegrationTest(unittest.TestCase): dynamic_shapes=dynamic_shapes, strict=False, ) + + +class SyntheticCacheTest(unittest.TestCase): + """Tests cache behavior with simple dummy data.""" + + def setUp(self): + """Set up common configuration and cache instances for all tests.""" + self.window_size = 4 + self.max_cache_len = 4 + self.config = Gemma2Config( + num_hidden_layers=1, + num_key_value_heads=1, + num_attention_heads=1, + head_dim=1, + hidden_size=1, + sliding_window=self.window_size, + sliding_window_pattern=2, # Default pattern for hybrid sliding + ) + + def test_static_cache_out_of_bounds(self): + """Test StaticCache raises IndexError for out-of-bounds positions.""" + static_cache = StaticCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + pos_out_of_bounds = torch.tensor([self.max_cache_len]) # Position >= max_cache_len + + with self.assertRaises(IndexError): + static_cache.update( + key_states=torch.tensor([[[[1.0]]]]), + value_states=torch.tensor([[[[1.0]]]]), + layer_idx=0, + cache_kwargs={"cache_position": pos_out_of_bounds}, + ) + + def test_static_cache(self): + """Test StaticCache with manually prefilled states and hardcoded assertions. 
+ + Scenario 1: Fill up to near capacity + prefill: [1.0, 2.0, 0.0, 0.0] + update pos 2: [1.0, 2.0, 3.0, 0.0] + + Scenario 2: Fill to capacity + update pos 3: [1.0, 2.0, 3.0, 4.0] + """ + # Scenario 1: Fill up to near capacity + static_cache = StaticCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None] + static_cache.update(key_states=prefill, value_states=prefill, layer_idx=0, cache_kwargs=None) + static_cache.update( + key_states=torch.tensor(3.0)[None, None, None, None], + value_states=torch.tensor(3.0)[None, None, None, None], + layer_idx=0, + cache_kwargs={"cache_position": torch.tensor([2])}, + ) + self.assertEqual( + static_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "StaticCache Scenario 1 failed" + ) + + # Scenario 2: Fill to capacity + static_cache.update( + key_states=torch.tensor(4.0)[None, None, None, None], + value_states=torch.tensor(4.0)[None, None, None, None], + layer_idx=0, + cache_kwargs={"cache_position": torch.tensor([3])}, + ) + self.assertEqual( + static_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "StaticCache Scenario 2 failed" + ) + + def test_sliding_window_cache(self): + """Test SlidingWindowCache with manually prefilled states and hardcoded assertions. + + Scenario 1: Update within window, no slide yet + prefill: [1.0, 2.0, 0.0, 0.0] + update pos 2: [1.0, 2.0, 3.0, 0.0] + + Scenario 2: Update causing slide + prefill: [1.0, 2.0, 3.0, 4.0] + update pos 4: [2.0, 3.0, 4.0, 5.0] (shift happens as pos > window_size-1) + + Scenario 3: Long prompt handling (prompt_len > window_size) + input: [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] + result: [3.0, 4.0, 5.0, 6.0] (keeps last window_size tokens) + """ + # Scenario 1: Update within window, no slide yet + sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None] + sliding_cache.update( + key_states=prefill, + value_states=prefill, + layer_idx=0, + cache_kwargs={"cache_position": torch.arange(4), "sliding_window": self.window_size}, + ) + sliding_cache.update( + key_states=torch.tensor(3.0)[None, None, None, None], + value_states=torch.tensor(3.0)[None, None, None, None], + layer_idx=0, + cache_kwargs={"cache_position": torch.tensor([2]), "sliding_window": self.window_size}, + ) + self.assertEqual( + sliding_cache.key_cache[0][0, 0, :, 0].tolist(), + [1.0, 2.0, 3.0, 0.0], + "SlidingWindowCache Scenario 1 failed", + ) + + # Scenario 2: Update causing slide + sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + prefill = torch.tensor([1.0, 2.0, 3.0, 4.0])[None, None, :, None] + sliding_cache.update( + key_states=prefill, + value_states=prefill, + layer_idx=0, + cache_kwargs={"cache_position": torch.arange(4), "sliding_window": self.window_size}, + ) + sliding_cache.update( + key_states=torch.tensor(5.0)[None, None, None, None], + value_states=torch.tensor(5.0)[None, None, None, None], + layer_idx=0, + cache_kwargs={"cache_position": torch.tensor([4]), "sliding_window": self.window_size}, + ) + self.assertEqual( + sliding_cache.key_cache[0][0, 0, :, 0].tolist(), + [2.0, 3.0, 4.0, 5.0], + "SlidingWindowCache Scenario 2 failed", + ) + + # Scenario 3: Long prompt handling + sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + long_prefill = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 
6.0])[None, None, :, None] + sliding_cache.update( + key_states=long_prefill, + value_states=long_prefill, + layer_idx=0, + cache_kwargs={"cache_position": torch.arange(6), "sliding_window": self.window_size}, + ) + self.assertEqual( + sliding_cache.key_cache[0][0, 0, :, 0].tolist(), + [3.0, 4.0, 5.0, 6.0], + "SlidingWindowCache Scenario 3 failed", + ) + + def test_hybrid_cache_static_mode(self): + """Test HybridCache in static mode with hardcoded assertions. + + Scenario 1: Static layer behavior + prefill: [1.0, 2.0, 0.0, 0.0] + update pos 2: [1.0, 2.0, 3.0, 0.0] + + Scenario 2: Fill to capacity + update pos 3: [1.0, 2.0, 3.0, 4.0] + """ + config = copy.deepcopy(self.config) + config.sliding_window_pattern = 1 # Layer 0 is static (1 % 1 == 0) + + # Scenario 1 + hybrid_cache_static_mode = HybridCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len) + prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None] + hybrid_cache_static_mode.update( + key_states=prefill, + value_states=prefill, + layer_idx=0, + cache_kwargs={"cache_position": torch.arange(4)}, + ) + hybrid_cache_static_mode.update( + key_states=torch.tensor(3.0)[None, None, None, None], + value_states=torch.tensor(3.0)[None, None, None, None], + layer_idx=0, + cache_kwargs={"cache_position": torch.tensor([2])}, + ) + self.assertEqual( + hybrid_cache_static_mode.key_cache[0][0, 0, :, 0].tolist(), + [1.0, 2.0, 3.0, 0.0], + "HybridCache Static Scenario 1 failed", + ) + + # Scenario 2 + hybrid_cache_static_mode.update( + key_states=torch.tensor(4.0)[None, None, None, None], + value_states=torch.tensor(4.0)[None, None, None, None], + layer_idx=0, + cache_kwargs={"cache_position": torch.tensor([3])}, + ) + self.assertEqual( + hybrid_cache_static_mode.key_cache[0][0, 0, :, 0].tolist(), + [1.0, 2.0, 3.0, 4.0], + "HybridCache Static Scenario 2 failed", + ) + + def test_hybrid_cache_sliding_mode(self): + """Test HybridCache in sliding mode with hardcoded assertions. 
+ + Scenario 1: Update within window, no slide yet + prefill: [1.0, 2.0, 0.0, 0.0] + update pos 2: [1.0, 2.0, 3.0, 0.0] + + Scenario 2: Update causing first slide + prefill: [1.0, 2.0, 3.0, 4.0] + update pos 4: [2.0, 3.0, 4.0, 5.0] (shift happens as pos > window_size-1) + + Scenario 3: Update causing subsequent slide + update pos 5: [3.0, 4.0, 5.0, 6.0] (shift continues) + + Scenario 4: Long prompt handling (prompt_len > window_size) + input: [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] + result: [3.0, 4.0, 5.0, 6.0] (keeps last window_size tokens) + """ + # Scenario 1: Update within window, no slide yet + hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None] + hybrid_cache.update( + key_states=prefill, + value_states=prefill, + layer_idx=0, + cache_kwargs={"cache_position": torch.arange(4), "sliding_window": self.window_size}, + ) + hybrid_cache.update( + key_states=torch.tensor(3.0)[None, None, None, None], + value_states=torch.tensor(3.0)[None, None, None, None], + layer_idx=0, + cache_kwargs={"cache_position": torch.tensor([2]), "sliding_window": self.window_size}, + ) + self.assertEqual( + hybrid_cache.key_cache[0][0, 0, :, 0].tolist(), + [1.0, 2.0, 3.0, 0.0], + "HybridCache Sliding Scenario 1 failed", + ) + + # Scenario 2: Update causing first slide + hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + prefill = torch.tensor([1.0, 2.0, 3.0, 4.0])[None, None, :, None] + hybrid_cache.update( + key_states=prefill, + value_states=prefill, + layer_idx=0, + cache_kwargs={"cache_position": torch.arange(4), "sliding_window": self.window_size}, + ) + hybrid_cache.update( + key_states=torch.tensor(5.0)[None, None, None, None], + value_states=torch.tensor(5.0)[None, None, None, None], + layer_idx=0, + cache_kwargs={"cache_position": torch.tensor([4]), "sliding_window": self.window_size}, + ) + self.assertEqual( + hybrid_cache.key_cache[0][0, 0, :, 0].tolist(), + [2.0, 3.0, 4.0, 5.0], + "HybridCache Sliding Scenario 2 failed", + ) + + # Scenario 3: Update causing subsequent slide + hybrid_cache.update( + key_states=torch.tensor(6.0)[None, None, None, None], + value_states=torch.tensor(6.0)[None, None, None, None], + layer_idx=0, + cache_kwargs={"cache_position": torch.tensor([5]), "sliding_window": self.window_size}, + ) + self.assertEqual( + hybrid_cache.key_cache[0][0, 0, :, 0].tolist(), + [3.0, 4.0, 5.0, 6.0], + "HybridCache Sliding Scenario 3 failed", + ) + + # Scenario 4: Long prompt handling + hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + long_prefill = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])[None, None, :, None] + hybrid_cache.update( + key_states=long_prefill, + value_states=long_prefill, + layer_idx=0, + cache_kwargs={"cache_position": torch.arange(6), "sliding_window": self.window_size}, + ) + self.assertEqual( + hybrid_cache.key_cache[0][0, 0, :, 0].tolist(), + [3.0, 4.0, 5.0, 6.0], + "HybridCache Sliding Scenario 4 failed", + )
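The scenarios above exercise a single layer at a time. As a quick illustration of how the refactored HybridCache allocates its per-layer buffers, here is a minimal sketch (not part of the patch) that assumes the constructor signature used in the tests above and the `is_sliding_list` attribute introduced in this diff; the config values are arbitrary toy numbers.

from transformers import Gemma2Config, HybridCache

config = Gemma2Config(
    num_hidden_layers=4,
    num_key_value_heads=1,
    num_attention_heads=1,
    head_dim=1,
    hidden_size=1,
    sliding_window=4,          # sliding layers get at most 4 slots
    sliding_window_pattern=2,  # every 2nd layer is global, the others slide
)
cache = HybridCache(config=config, max_batch_size=1, max_cache_len=8)

# Layers alternate sliding/global, and sliding layers are capped at min(sliding_window, max_cache_len).
print(cache.is_sliding_list)                  # expected: [True, False, True, False]
print([k.shape[2] for k in cache.key_cache])  # expected: [4, 8, 4, 8]

This mirrors the `global_cache_shape` / `sliding_cache_shape` split in the refactored `HybridCache.__init__` above.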