diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 453d4a44bbd..513e7579f8f 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -651,9 +651,7 @@ class DynamicCache(Cache): # Utilities for `DynamicCache` <> torch.export support -def _flatten_dynamic_cache( - dynamic_cache: DynamicCache, -): +def _flatten_dynamic_cache(dynamic_cache: DynamicCache): """Flattens DynamicCache into flat list of tensors for `torch.export.export` to consume""" if not isinstance(dynamic_cache, DynamicCache): raise RuntimeError("This pytree flattening function should only be applied to DynamicCache") @@ -682,7 +680,7 @@ def _flatten_with_keys_dynamic_cache(dynamic_cache: DynamicCache): def _unflatten_dynamic_cache( values, context: torch.utils._pytree.Context, -): +) -> DynamicCache: dictionary = torch.utils._pytree._dict_unflatten(values, context) cache = DynamicCache() for k, v in dictionary.items(): @@ -698,6 +696,43 @@ def _flatten_dynamic_cache_for_fx(cache, spec): return torch.utils._pytree.tree_flatten(dictionary)[0] +# Utilities for `StaticCache` <> torch.export support +def _flatten_static_cache(cache: StaticCache): + """Flattens DynamicCache into flat list of tensors for `torch.export.export` to consume""" + if not isinstance(cache, StaticCache): + raise RuntimeError("This pytree flattening function should only be applied to StaticCache") + + if not is_torch_greater_or_equal_than_2_7: + logger.warning_once( + "StaticCache + torch.export is tested on torch 2.7.0+ and may not work on earlier versions." + ) + + dictionary = { + "key_cache": getattr(cache, "key_cache"), + "value_cache": getattr(cache, "value_cache"), + } + return torch.utils._pytree._dict_flatten(dictionary) + + +def _flatten_with_keys_static_cache(cache: StaticCache): + dictionary = { + "key_cache": getattr(cache, "key_cache"), + "value_cache": getattr(cache, "value_cache"), + } + return torch.utils._pytree._dict_flatten_with_keys(dictionary) + + +def _unflatten_static_cache( + values, + context: torch.utils._pytree.Context, +) -> StaticCache: + dictionary = torch.utils._pytree._dict_unflatten(values, context) + cache = StaticCache() + for k, v in dictionary.items(): + setattr(cache, k, v) + return cache + + if is_torch_greater_or_equal("2.3"): torch.utils._pytree.register_pytree_node( DynamicCache, @@ -709,6 +744,15 @@ if is_torch_greater_or_equal("2.3"): # TODO (tmanlaibaatar) This won't be needed in torch 2.7. torch.fx._pytree.register_pytree_flatten_spec(DynamicCache, _flatten_dynamic_cache_for_fx) +if is_torch_greater_or_equal("2.7"): + torch.utils._pytree.register_pytree_node( + StaticCache, + _flatten_static_cache, + _unflatten_static_cache, + serialized_type_name=f"{StaticCache.__module__}.{StaticCache.__name__}", + flatten_with_keys_fn=_flatten_with_keys_static_cache, + ) + class OffloadedCache(DynamicCache): """ diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index 150bd9e1b2b..f12fb925da9 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -194,6 +194,40 @@ class CacheTest(unittest.TestCase): self.assertTrue(cached_keys.shape == (1, 1, 10, 128)) self.assertTrue(cached_values.shape == (1, 1, 10, 128)) + def test_unflatten_flatten_static_cache(self): + + def make_static_cache(key_value_pairs): + class _config: + def __init__(self): + self.head_dim = key_value_pairs[0][0].shape[-1] + self.num_attention_heads = key_value_pairs[0][0].shape[1] + self.num_hidden_layers = len(key_value_pairs) + + cache = transformers.cache_utils.StaticCache( + _config(), + max_batch_size=key_value_pairs[0][0].shape[0], + device=key_value_pairs[0][0].device, + dtype=key_value_pairs[0][0].dtype, + max_cache_len=key_value_pairs[0][0].shape[2], + ) + for i in range(len(key_value_pairs)): + cache.key_cache[i][:, :, :, :] = key_value_pairs[i][0] + cache.value_cache[i][:, :, :, :] = key_value_pairs[i][1] + return cache + + cache = make_static_cache( + [ + (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))), + (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))), + (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))), + ] + ) + flat, _spec = torch.utils._pytree.tree_flatten(cache) + self.assertIsInstance(flat, list) + self.assertEqual(len(flat), 6) + cache2 = torch.utils._pytree.tree_unflatten(flat, spec) + self.assertTrue(isinstance(cache2, StaticCache)) + def _skip_on_failed_cache_prerequisites(test, cache_implementation): """Function to skip tests on failed cache prerequisites, given a cache implementation""" @@ -630,6 +664,8 @@ class CacheExportIntegrationTest(unittest.TestCase): """ Tests that static cache works with `torch.export()` """ + # TODO: Another test should be implemented to follow the same pattern + # as the one implemented for DynamicCache. if not is_torch_greater_or_equal("2.3"): self.skipTest(reason="This test requires torch >= 2.3 to run.")