From a81079c687863df51d4e05aac15cdaa6699a117d Mon Sep 17 00:00:00 2001
From: xadupre
Date: Wed, 18 Jun 2025 08:19:00 +0200
Subject: [PATCH 1/4] Add serialization function for StaticCache

---
 src/transformers/cache_utils.py | 52 ++++++++++++++++++++++++++++++---
 tests/utils/test_cache_utils.py | 36 +++++++++++++++++++++++
 2 files changed, 84 insertions(+), 4 deletions(-)

diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index 453d4a44bbd..513e7579f8f 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -651,9 +651,7 @@ class DynamicCache(Cache):
 
 
 # Utilities for `DynamicCache` <> torch.export support
-def _flatten_dynamic_cache(
-    dynamic_cache: DynamicCache,
-):
+def _flatten_dynamic_cache(dynamic_cache: DynamicCache):
     """Flattens DynamicCache into flat list of tensors for `torch.export.export` to consume"""
     if not isinstance(dynamic_cache, DynamicCache):
         raise RuntimeError("This pytree flattening function should only be applied to DynamicCache")
@@ -682,7 +680,7 @@ def _flatten_with_keys_dynamic_cache(dynamic_cache: DynamicCache):
 def _unflatten_dynamic_cache(
     values,
     context: torch.utils._pytree.Context,
-):
+) -> DynamicCache:
     dictionary = torch.utils._pytree._dict_unflatten(values, context)
     cache = DynamicCache()
     for k, v in dictionary.items():
@@ -698,6 +696,43 @@ def _flatten_dynamic_cache_for_fx(cache, spec):
     return torch.utils._pytree.tree_flatten(dictionary)[0]
 
 
+# Utilities for `StaticCache` <> torch.export support
+def _flatten_static_cache(cache: StaticCache):
+    """Flattens StaticCache into flat list of tensors for `torch.export.export` to consume"""
+    if not isinstance(cache, StaticCache):
+        raise RuntimeError("This pytree flattening function should only be applied to StaticCache")
+
+    if not is_torch_greater_or_equal_than_2_7:
+        logger.warning_once(
+            "StaticCache + torch.export is tested on torch 2.7.0+ and may not work on earlier versions."
+        )
+
+    dictionary = {
+        "key_cache": getattr(cache, "key_cache"),
+        "value_cache": getattr(cache, "value_cache"),
+    }
+    return torch.utils._pytree._dict_flatten(dictionary)
+
+
+def _flatten_with_keys_static_cache(cache: StaticCache):
+    dictionary = {
+        "key_cache": getattr(cache, "key_cache"),
+        "value_cache": getattr(cache, "value_cache"),
+    }
+    return torch.utils._pytree._dict_flatten_with_keys(dictionary)
+
+
+def _unflatten_static_cache(
+    values,
+    context: torch.utils._pytree.Context,
+) -> StaticCache:
+    dictionary = torch.utils._pytree._dict_unflatten(values, context)
+    cache = StaticCache()
+    for k, v in dictionary.items():
+        setattr(cache, k, v)
+    return cache
+
+
 if is_torch_greater_or_equal("2.3"):
     torch.utils._pytree.register_pytree_node(
         DynamicCache,
@@ -709,6 +744,15 @@ if is_torch_greater_or_equal("2.3"):
     # TODO (tmanlaibaatar) This won't be needed in torch 2.7.
     torch.fx._pytree.register_pytree_flatten_spec(DynamicCache, _flatten_dynamic_cache_for_fx)
 
+if is_torch_greater_or_equal("2.7"):
+    torch.utils._pytree.register_pytree_node(
+        StaticCache,
+        _flatten_static_cache,
+        _unflatten_static_cache,
+        serialized_type_name=f"{StaticCache.__module__}.{StaticCache.__name__}",
+        flatten_with_keys_fn=_flatten_with_keys_static_cache,
+    )
+
 
 class OffloadedCache(DynamicCache):
     """
diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py
index 150bd9e1b2b..f12fb925da9 100644
--- a/tests/utils/test_cache_utils.py
+++ b/tests/utils/test_cache_utils.py
@@ -194,6 +194,40 @@ class CacheTest(unittest.TestCase):
         self.assertTrue(cached_keys.shape == (1, 1, 10, 128))
         self.assertTrue(cached_values.shape == (1, 1, 10, 128))
 
+    def test_unflatten_flatten_static_cache(self):
+
+        def make_static_cache(key_value_pairs):
+            class _config:
+                def __init__(self):
+                    self.head_dim = key_value_pairs[0][0].shape[-1]
+                    self.num_attention_heads = key_value_pairs[0][0].shape[1]
+                    self.num_hidden_layers = len(key_value_pairs)
+
+            cache = transformers.cache_utils.StaticCache(
+                _config(),
+                max_batch_size=key_value_pairs[0][0].shape[0],
+                device=key_value_pairs[0][0].device,
+                dtype=key_value_pairs[0][0].dtype,
+                max_cache_len=key_value_pairs[0][0].shape[2],
+            )
+            for i in range(len(key_value_pairs)):
+                cache.key_cache[i][:, :, :, :] = key_value_pairs[i][0]
+                cache.value_cache[i][:, :, :, :] = key_value_pairs[i][1]
+            return cache
+
+        cache = make_static_cache(
+            [
+                (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+                (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+                (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+            ]
+        )
+        flat, _spec = torch.utils._pytree.tree_flatten(cache)
+        self.assertIsInstance(flat, list)
+        self.assertEqual(len(flat), 6)
+        cache2 = torch.utils._pytree.tree_unflatten(flat, spec)
+        self.assertTrue(isinstance(cache2, StaticCache))
+
 
 def _skip_on_failed_cache_prerequisites(test, cache_implementation):
     """Function to skip tests on failed cache prerequisites, given a cache implementation"""
@@ -630,6 +664,8 @@ class CacheExportIntegrationTest(unittest.TestCase):
         """
         Tests that static cache works with `torch.export()`
         """
+        # TODO: Another test should be implemented to follow the same pattern
+        # as the one implemented for DynamicCache.
         if not is_torch_greater_or_equal("2.3"):
             self.skipTest(reason="This test requires torch >= 2.3 to run.")
 

From c5f7c8e44f2260fe7826a22168ead1e55f43ca07 Mon Sep 17 00:00:00 2001
From: xadupre
Date: Wed, 18 Jun 2025 08:45:58 +0200
Subject: [PATCH 2/4] move code

---
 src/transformers/cache_utils.py | 95 +++++++++++++++++----------------
 tests/utils/test_cache_utils.py |  4 +-
 2 files changed, 50 insertions(+), 49 deletions(-)

diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index 513e7579f8f..44b419a47bf 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -9,7 +9,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import torch
 from packaging import version
 
-from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_6
+from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_6, is_torch_greater_or_equal_than_2_7
 
 from .configuration_utils import PretrainedConfig
 from .utils import is_hqq_available, is_optimum_quanto_available, is_torch_greater_or_equal, logging
@@ -696,43 +696,6 @@ def _flatten_dynamic_cache_for_fx(cache, spec):
     return torch.utils._pytree.tree_flatten(dictionary)[0]
 
 
-# Utilities for `StaticCache` <> torch.export support
-def _flatten_static_cache(cache: StaticCache):
-    """Flattens StaticCache into flat list of tensors for `torch.export.export` to consume"""
-    if not isinstance(cache, StaticCache):
-        raise RuntimeError("This pytree flattening function should only be applied to StaticCache")
-
-    if not is_torch_greater_or_equal_than_2_7:
-        logger.warning_once(
-            "StaticCache + torch.export is tested on torch 2.7.0+ and may not work on earlier versions."
-        )
-
-    dictionary = {
-        "key_cache": getattr(cache, "key_cache"),
-        "value_cache": getattr(cache, "value_cache"),
-    }
-    return torch.utils._pytree._dict_flatten(dictionary)
-
-
-def _flatten_with_keys_static_cache(cache: StaticCache):
-    dictionary = {
-        "key_cache": getattr(cache, "key_cache"),
-        "value_cache": getattr(cache, "value_cache"),
-    }
-    return torch.utils._pytree._dict_flatten_with_keys(dictionary)
-
-
-def _unflatten_static_cache(
-    values,
-    context: torch.utils._pytree.Context,
-) -> StaticCache:
-    dictionary = torch.utils._pytree._dict_unflatten(values, context)
-    cache = StaticCache()
-    for k, v in dictionary.items():
-        setattr(cache, k, v)
-    return cache
-
-
 if is_torch_greater_or_equal("2.3"):
     torch.utils._pytree.register_pytree_node(
         DynamicCache,
@@ -744,15 +707,6 @@ if is_torch_greater_or_equal("2.3"):
     # TODO (tmanlaibaatar) This won't be needed in torch 2.7.
     torch.fx._pytree.register_pytree_flatten_spec(DynamicCache, _flatten_dynamic_cache_for_fx)
 
-if is_torch_greater_or_equal("2.7"):
-    torch.utils._pytree.register_pytree_node(
-        StaticCache,
-        _flatten_static_cache,
-        _unflatten_static_cache,
-        serialized_type_name=f"{StaticCache.__module__}.{StaticCache.__name__}",
-        flatten_with_keys_fn=_flatten_with_keys_static_cache,
-    )
-
 
 class OffloadedCache(DynamicCache):
     """
@@ -1274,6 +1228,53 @@ class StaticCache(Cache):
         return kv_length, 0
 
 
+# Utilities for `StaticCache` <> torch.export support
+def _flatten_static_cache(cache: StaticCache):
+    """Flattens StaticCache into flat list of tensors for `torch.export.export` to consume"""
+    if not isinstance(cache, StaticCache):
+        raise RuntimeError("This pytree flattening function should only be applied to StaticCache")
+
+    if not is_torch_greater_or_equal_than_2_7:
+        logger.warning_once(
+            "StaticCache + torch.export is tested on torch 2.7.0+ and may not work on earlier versions."
+        )
+
+    dictionary = {
+        "key_cache": getattr(cache, "key_cache"),
+        "value_cache": getattr(cache, "value_cache"),
+    }
+    return torch.utils._pytree._dict_flatten(dictionary)
+
+
+def _flatten_with_keys_static_cache(cache: StaticCache):
+    dictionary = {
+        "key_cache": getattr(cache, "key_cache"),
+        "value_cache": getattr(cache, "value_cache"),
+    }
+    return torch.utils._pytree._dict_flatten_with_keys(dictionary)
+
+
+def _unflatten_static_cache(
+    values,
+    context: torch.utils._pytree.Context,
+) -> StaticCache:
+    dictionary = torch.utils._pytree._dict_unflatten(values, context)
+    cache = StaticCache()
+    for k, v in dictionary.items():
+        setattr(cache, k, v)
+    return cache
+
+
+if is_torch_greater_or_equal("2.7"):
+    torch.utils._pytree.register_pytree_node(
+        StaticCache,
+        _flatten_static_cache,
+        _unflatten_static_cache,
+        serialized_type_name=f"{StaticCache.__module__}.{StaticCache.__name__}",
+        flatten_with_keys_fn=_flatten_with_keys_static_cache,
+    )
+
+
 class SlidingWindowCache(StaticCache):
     """
     Sliding Window Cache class to be used with `torch.compile` for models like Mistral that support sliding window attention.
diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py
index f12fb925da9..f778f016295 100644
--- a/tests/utils/test_cache_utils.py
+++ b/tests/utils/test_cache_utils.py
@@ -203,7 +203,7 @@ class CacheTest(unittest.TestCase):
                     self.num_attention_heads = key_value_pairs[0][0].shape[1]
                     self.num_hidden_layers = len(key_value_pairs)
 
-            cache = transformers.cache_utils.StaticCache(
+            cache = StaticCache(
                 _config(),
                 max_batch_size=key_value_pairs[0][0].shape[0],
                 device=key_value_pairs[0][0].device,
@@ -222,7 +222,7 @@ class CacheTest(unittest.TestCase):
                 (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
             ]
         )
-        flat, _spec = torch.utils._pytree.tree_flatten(cache)
+        flat, spec = torch.utils._pytree.tree_flatten(cache)
         self.assertIsInstance(flat, list)
         self.assertEqual(len(flat), 6)
         cache2 = torch.utils._pytree.tree_unflatten(flat, spec)

From 11d2d67dbaec68eeea2e19b348245eead0b75e9a Mon Sep 17 00:00:00 2001
From: xadupre
Date: Wed, 18 Jun 2025 09:00:14 +0200
Subject: [PATCH 3/4] fix test

---
 src/transformers/cache_utils.py   | 26 +++++++++++++++++++++-----
 src/transformers/pytorch_utils.py |  1 +
 2 files changed, 22 insertions(+), 5 deletions(-)

diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index 44b419a47bf..6d4ef1d527b 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -1254,15 +1254,31 @@ def _flatten_with_keys_static_cache(cache: StaticCache):
     return torch.utils._pytree._dict_flatten_with_keys(dictionary)
 
 
+def _make_static_cache(key_value_pairs):
+    class _config:
+        def __init__(self):
+            self.head_dim = key_value_pairs[0][0].shape[-1]
+            self.num_attention_heads = key_value_pairs[0][0].shape[1]
+            self.num_hidden_layers = len(key_value_pairs)
+
+    cache = StaticCache(
+        _config(),
+        max_batch_size=key_value_pairs[0][0].shape[0],
+        device=key_value_pairs[0][0].device,
+        dtype=key_value_pairs[0][0].dtype,
+        max_cache_len=key_value_pairs[0][0].shape[2],
+    )
+    for i in range(len(key_value_pairs)):
+        cache.key_cache[i][:, :, :, :] = key_value_pairs[i][0]
+        cache.value_cache[i][:, :, :, :] = key_value_pairs[i][1]
+    return cache
+
+
 def _unflatten_static_cache(
     values,
     context: torch.utils._pytree.Context,
 ) -> StaticCache:
-    dictionary = torch.utils._pytree._dict_unflatten(values, context)
-    cache = StaticCache()
-    for k, v in dictionary.items():
-        setattr(cache, k, v)
-    return cache
+    return _make_static_cache(list(zip(values[0], values[1])))
 
 
 if is_torch_greater_or_equal("2.7"):
diff --git a/src/transformers/pytorch_utils.py b/src/transformers/pytorch_utils.py
index 9bb02bff963..7de581cee26 100644
--- a/src/transformers/pytorch_utils.py
+++ b/src/transformers/pytorch_utils.py
@@ -28,6 +28,7 @@ ALL_LAYERNORM_LAYERS = [nn.LayerNorm]
 
 logger = logging.get_logger(__name__)
 
+is_torch_greater_or_equal_than_2_7 = is_torch_greater_or_equal("2.7", accept_dev=True)
 is_torch_greater_or_equal_than_2_6 = is_torch_greater_or_equal("2.6", accept_dev=True)
 is_torch_greater_or_equal_than_2_4 = is_torch_greater_or_equal("2.4", accept_dev=True)
 is_torch_greater_or_equal_than_2_3 = is_torch_greater_or_equal("2.3", accept_dev=True)

From 2091dad578f4b643031742a8e5226f03a107aa25 Mon Sep 17 00:00:00 2001
From: xadupre
Date: Wed, 18 Jun 2025 09:04:18 +0200
Subject: [PATCH 4/4] ruff

---
 tests/utils/test_cache_utils.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py
index f778f016295..e9ac3e12fb8 100644
--- a/tests/utils/test_cache_utils.py
+++ b/tests/utils/test_cache_utils.py
@@ -195,7 +195,6 @@ class CacheTest(unittest.TestCase):
         self.assertTrue(cached_values.shape == (1, 1, 10, 128))
 
     def test_unflatten_flatten_static_cache(self):
-
         def make_static_cache(key_value_pairs):
             class _config:
                 def __init__(self):
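
After this series, a `StaticCache` round-trips through `torch.utils._pytree`, which is what `torch.export.export` relies on to serialize the cache. Below is a minimal usage sketch of that round-trip (not part of the patches themselves), mirroring the new unit test and the `_make_static_cache` helper from patch 3. It assumes torch >= 2.7, where the pytree node above is registered, and uses a bare `_config` stub rather than a real model config.

import torch

from transformers.cache_utils import StaticCache

# Three layers of (key, value) tensors shaped [batch, heads, seq_len, head_dim].
pairs = [(torch.rand(4, 5, 6, 7), torch.rand(4, 5, 6, 7)) for _ in range(3)]


class _config:
    # Bare stub standing in for a real PretrainedConfig; StaticCache only
    # needs these attributes in this sketch.
    def __init__(self):
        self.head_dim = pairs[0][0].shape[-1]
        self.num_attention_heads = pairs[0][0].shape[1]
        self.num_hidden_layers = len(pairs)


cache = StaticCache(
    _config(),
    max_batch_size=pairs[0][0].shape[0],
    device=pairs[0][0].device,
    dtype=pairs[0][0].dtype,
    max_cache_len=pairs[0][0].shape[2],
)
for i, (k, v) in enumerate(pairs):
    cache.key_cache[i][:, :, :, :] = k
    cache.value_cache[i][:, :, :, :] = v

# Flatten to plain tensors plus a spec: one key and one value tensor per layer.
flat, spec = torch.utils._pytree.tree_flatten(cache)
assert len(flat) == 2 * len(pairs)

# Rebuild the cache; after patch 3 this goes through `_make_static_cache`.
cache2 = torch.utils._pytree.tree_unflatten(flat, spec)
assert isinstance(cache2, StaticCache)

Because the flatten path exposes only `key_cache` and `value_cache`, everything else (head dimension, head and layer counts, batch size, cache length) is re-derived from the leaf tensor shapes on unflatten.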