Add serialization function for StaticCache

This commit is contained in:
xadupre 2025-06-18 08:19:00 +02:00
parent e5a9ce48f7
commit a81079c687
2 changed files with 84 additions and 4 deletions

View File

@ -651,9 +651,7 @@ class DynamicCache(Cache):
# Utilities for `DynamicCache` <> torch.export support
def _flatten_dynamic_cache(
dynamic_cache: DynamicCache,
):
def _flatten_dynamic_cache(dynamic_cache: DynamicCache):
"""Flattens DynamicCache into flat list of tensors for `torch.export.export` to consume"""
if not isinstance(dynamic_cache, DynamicCache):
raise RuntimeError("This pytree flattening function should only be applied to DynamicCache")
@ -682,7 +680,7 @@ def _flatten_with_keys_dynamic_cache(dynamic_cache: DynamicCache):
def _unflatten_dynamic_cache(
values,
context: torch.utils._pytree.Context,
):
) -> DynamicCache:
dictionary = torch.utils._pytree._dict_unflatten(values, context)
cache = DynamicCache()
for k, v in dictionary.items():
@ -698,6 +696,43 @@ def _flatten_dynamic_cache_for_fx(cache, spec):
return torch.utils._pytree.tree_flatten(dictionary)[0]
# Utilities for `StaticCache` <> torch.export support
def _flatten_static_cache(cache: StaticCache):
    """Flattens StaticCache into a flat list of tensors for `torch.export.export` to consume.

    The cache is exposed to pytree as a dict with "key_cache" and "value_cache"
    entries, mirroring the DynamicCache flattening utilities above.

    Raises:
        RuntimeError: if `cache` is not a StaticCache instance.
    """
    if not isinstance(cache, StaticCache):
        raise RuntimeError("This pytree flattening function should only be applied to StaticCache")

    if not is_torch_greater_or_equal_than_2_7:
        logger.warning_once(
            "StaticCache + torch.export is tested on torch 2.7.0+ and may not work on earlier versions."
        )

    # Plain attribute access — `getattr` with a constant string name is redundant.
    dictionary = {
        "key_cache": cache.key_cache,
        "value_cache": cache.value_cache,
    }
    return torch.utils._pytree._dict_flatten(dictionary)
def _flatten_with_keys_static_cache(cache: StaticCache):
    """Flattens StaticCache with key paths, for `torch.export` pytree support.

    Same layout as `_flatten_static_cache`, but returns (key, value) path pairs
    as produced by `_dict_flatten_with_keys`.
    """
    # Plain attribute access — `getattr` with a constant string name is redundant.
    dictionary = {
        "key_cache": cache.key_cache,
        "value_cache": cache.value_cache,
    }
    return torch.utils._pytree._dict_flatten_with_keys(dictionary)
def _unflatten_static_cache(
    values,
    context: torch.utils._pytree.Context,
) -> StaticCache:
    """Rebuilds a StaticCache from the flat tensor list + pytree context.

    Inverse of `_flatten_static_cache`: reconstructs the {"key_cache",
    "value_cache"} dict and assigns each entry back onto a fresh cache.
    """
    dictionary = torch.utils._pytree._dict_unflatten(values, context)
    # NOTE(review): `StaticCache()` is called with no arguments here, while the
    # test at the bottom of this commit constructs it with a config and
    # max_batch_size/max_cache_len — confirm the constructor accepts zero args.
    cache = StaticCache()
    for k, v in dictionary.items():
        # k is "key_cache" / "value_cache"; overwrite the attributes wholesale.
        setattr(cache, k, v)
    return cache
if is_torch_greater_or_equal("2.3"):
torch.utils._pytree.register_pytree_node(
DynamicCache,
@ -709,6 +744,15 @@ if is_torch_greater_or_equal("2.3"):
# TODO (tmanlaibaatar) This won't be needed in torch 2.7.
torch.fx._pytree.register_pytree_flatten_spec(DynamicCache, _flatten_dynamic_cache_for_fx)
# Register StaticCache with torch's pytree so torch.export can (un)flatten it.
# Gated on torch >= 2.7, matching the version warning in _flatten_static_cache.
if is_torch_greater_or_equal("2.7"):
    torch.utils._pytree.register_pytree_node(
        StaticCache,
        _flatten_static_cache,
        _unflatten_static_cache,
        serialized_type_name=f"{StaticCache.__module__}.{StaticCache.__name__}",
        flatten_with_keys_fn=_flatten_with_keys_static_cache,
    )
class OffloadedCache(DynamicCache):
"""

View File

@ -194,6 +194,40 @@ class CacheTest(unittest.TestCase):
self.assertTrue(cached_keys.shape == (1, 1, 10, 128))
self.assertTrue(cached_values.shape == (1, 1, 10, 128))
def test_unflatten_flatten_static_cache(self):
    """Round-trips a StaticCache through pytree tree_flatten/tree_unflatten."""

    def make_static_cache(key_value_pairs):
        # Minimal config stub carrying only the fields StaticCache reads.
        class _config:
            def __init__(self):
                self.head_dim = key_value_pairs[0][0].shape[-1]
                self.num_attention_heads = key_value_pairs[0][0].shape[1]
                self.num_hidden_layers = len(key_value_pairs)

        cache = transformers.cache_utils.StaticCache(
            _config(),
            max_batch_size=key_value_pairs[0][0].shape[0],
            device=key_value_pairs[0][0].device,
            dtype=key_value_pairs[0][0].dtype,
            max_cache_len=key_value_pairs[0][0].shape[2],
        )
        # Fill the preallocated cache tensors in place, one layer at a time.
        for i, (key, value) in enumerate(key_value_pairs):
            cache.key_cache[i][:, :, :, :] = key
            cache.value_cache[i][:, :, :, :] = value
        return cache

    cache = make_static_cache(
        [
            (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
            (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
            (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
        ]
    )
    # 3 layers x (key + value) tensors = 6 leaves.
    # Bug fix: the spec was previously bound to `_spec` but used as `spec`,
    # which raised NameError before the round-trip was exercised.
    flat, spec = torch.utils._pytree.tree_flatten(cache)
    self.assertIsInstance(flat, list)
    self.assertEqual(len(flat), 6)
    cache2 = torch.utils._pytree.tree_unflatten(flat, spec)
    self.assertIsInstance(cache2, StaticCache)
def _skip_on_failed_cache_prerequisites(test, cache_implementation):
"""Function to skip tests on failed cache prerequisites, given a cache implementation"""
@ -630,6 +664,8 @@ class CacheExportIntegrationTest(unittest.TestCase):
"""
Tests that static cache works with `torch.export()`
"""
# TODO: Another test should be implemented to follow the same pattern
# as the one implemented for DynamicCache.
if not is_torch_greater_or_equal("2.3"):
self.skipTest(reason="This test requires torch >= 2.3 to run.")