Add serialization function for StaticCache
This commit is contained in: parent e5a9ce48f7, commit a81079c687
src/transformers/cache_utils.py

@@ -651,9 +651,7 @@ class DynamicCache(Cache):
 # Utilities for `DynamicCache` <> torch.export support
-def _flatten_dynamic_cache(
-    dynamic_cache: DynamicCache,
-):
+def _flatten_dynamic_cache(dynamic_cache: DynamicCache):
     """Flattens DynamicCache into flat list of tensors for `torch.export.export` to consume"""
     if not isinstance(dynamic_cache, DynamicCache):
         raise RuntimeError("This pytree flattening function should only be applied to DynamicCache")
@@ -682,7 +680,7 @@ def _flatten_with_keys_dynamic_cache(dynamic_cache: DynamicCache):
 def _unflatten_dynamic_cache(
     values,
     context: torch.utils._pytree.Context,
-):
+) -> DynamicCache:
     dictionary = torch.utils._pytree._dict_unflatten(values, context)
     cache = DynamicCache()
     for k, v in dictionary.items():
@@ -698,6 +696,43 @@ def _flatten_dynamic_cache_for_fx(cache, spec):
     return torch.utils._pytree.tree_flatten(dictionary)[0]


+# Utilities for `StaticCache` <> torch.export support
+def _flatten_static_cache(cache: StaticCache):
+    """Flattens StaticCache into flat list of tensors for `torch.export.export` to consume"""
+    if not isinstance(cache, StaticCache):
+        raise RuntimeError("This pytree flattening function should only be applied to StaticCache")
+
+    if not is_torch_greater_or_equal_than_2_7:
+        logger.warning_once(
+            "StaticCache + torch.export is tested on torch 2.7.0+ and may not work on earlier versions."
+        )
+
+    dictionary = {
+        "key_cache": getattr(cache, "key_cache"),
+        "value_cache": getattr(cache, "value_cache"),
+    }
+    return torch.utils._pytree._dict_flatten(dictionary)
+
+
+def _flatten_with_keys_static_cache(cache: StaticCache):
+    dictionary = {
+        "key_cache": getattr(cache, "key_cache"),
+        "value_cache": getattr(cache, "value_cache"),
+    }
+    return torch.utils._pytree._dict_flatten_with_keys(dictionary)
+
+
+def _unflatten_static_cache(
+    values,
+    context: torch.utils._pytree.Context,
+) -> StaticCache:
+    dictionary = torch.utils._pytree._dict_unflatten(values, context)
+    cache = StaticCache()
+    for k, v in dictionary.items():
+        setattr(cache, k, v)
+    return cache
+
+
 if is_torch_greater_or_equal("2.3"):
     torch.utils._pytree.register_pytree_node(
         DynamicCache,
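The three helpers above follow the standard torch pytree contract: the flatten function returns (leaves, context), and the unflatten function rebuilds an equivalent object from them. Below is a minimal, self-contained sketch of the same pattern on a hypothetical ToyCache container (illustrative names, not part of this commit):

import torch


class ToyCache:
    """Hypothetical stand-in for StaticCache, for illustration only."""

    def __init__(self):
        self.key_cache = []
        self.value_cache = []


def _flatten_toy_cache(cache):
    # Same move as _flatten_static_cache: expose the two tensor lists as dict values.
    dictionary = {"key_cache": cache.key_cache, "value_cache": cache.value_cache}
    return torch.utils._pytree._dict_flatten(dictionary)


def _unflatten_toy_cache(values, context):
    # Rebuild the container from the leaves plus the context saved at flatten time.
    dictionary = torch.utils._pytree._dict_unflatten(values, context)
    cache = ToyCache()
    for k, v in dictionary.items():
        setattr(cache, k, v)
    return cache


torch.utils._pytree.register_pytree_node(ToyCache, _flatten_toy_cache, _unflatten_toy_cache)

cache = ToyCache()
cache.key_cache = [torch.zeros(2, 3)]
cache.value_cache = [torch.ones(2, 3)]

leaves, spec = torch.utils._pytree.tree_flatten(cache)
assert len(leaves) == 2  # one key tensor + one value tensor
restored = torch.utils._pytree.tree_unflatten(leaves, spec)
assert isinstance(restored, ToyCache)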
@@ -709,6 +744,15 @@ if is_torch_greater_or_equal("2.3"):
     # TODO (tmanlaibaatar) This won't be needed in torch 2.7.
     torch.fx._pytree.register_pytree_flatten_spec(DynamicCache, _flatten_dynamic_cache_for_fx)

+if is_torch_greater_or_equal("2.7"):
+    torch.utils._pytree.register_pytree_node(
+        StaticCache,
+        _flatten_static_cache,
+        _unflatten_static_cache,
+        serialized_type_name=f"{StaticCache.__module__}.{StaticCache.__name__}",
+        flatten_with_keys_fn=_flatten_with_keys_static_cache,
+    )
+

 class OffloadedCache(DynamicCache):
     """
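Once the node is registered, a StaticCache round-trips through torch.utils._pytree like any built-in container, which is the property `torch.export.export` relies on. A rough usage sketch on torch 2.7+, with an illustrative model choice (not from the diff):

import torch
from transformers import AutoModelForCausalLM, StaticCache

# Illustrative model; any decoder-only config with the expected fields works.
model = AutoModelForCausalLM.from_pretrained("gpt2")
cache = StaticCache(model.config, max_batch_size=1, max_cache_len=32)

# With the pytree node registered, the cache round-trips like a dict or list.
leaves, spec = torch.utils._pytree.tree_flatten(cache)
restored = torch.utils._pytree.tree_unflatten(leaves, spec)
assert isinstance(restored, StaticCache)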
tests/utils/test_cache_utils.py

@@ -194,6 +194,40 @@ class CacheTest(unittest.TestCase):
         self.assertTrue(cached_keys.shape == (1, 1, 10, 128))
         self.assertTrue(cached_values.shape == (1, 1, 10, 128))

+    def test_unflatten_flatten_static_cache(self):
+        def make_static_cache(key_value_pairs):
+            class _config:
+                def __init__(self):
+                    self.head_dim = key_value_pairs[0][0].shape[-1]
+                    self.num_attention_heads = key_value_pairs[0][0].shape[1]
+                    self.num_hidden_layers = len(key_value_pairs)
+
+            cache = transformers.cache_utils.StaticCache(
+                _config(),
+                max_batch_size=key_value_pairs[0][0].shape[0],
+                device=key_value_pairs[0][0].device,
+                dtype=key_value_pairs[0][0].dtype,
+                max_cache_len=key_value_pairs[0][0].shape[2],
+            )
+            for i in range(len(key_value_pairs)):
+                cache.key_cache[i][:, :, :, :] = key_value_pairs[i][0]
+                cache.value_cache[i][:, :, :, :] = key_value_pairs[i][1]
+            return cache
+
+        cache = make_static_cache(
+            [
+                (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+                (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+                (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+            ]
+        )
+        flat, spec = torch.utils._pytree.tree_flatten(cache)
+        self.assertIsInstance(flat, list)
+        self.assertEqual(len(flat), 6)  # 3 layers x (key, value)
+        cache2 = torch.utils._pytree.tree_unflatten(flat, spec)
+        self.assertTrue(isinstance(cache2, StaticCache))
+
+
 def _skip_on_failed_cache_prerequisites(test, cache_implementation):
     """Function to skip tests on failed cache prerequisites, given a cache implementation"""
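The test exercises plain flattening only; the registered flatten_with_keys_fn additionally lets torch name each leaf by its path. A short sketch of that, reusing the test's minimal-config trick (assumed shapes, torch 2.7+):

import torch
from transformers.cache_utils import StaticCache


class _config:
    # Same minimal config trick as in the test above.
    head_dim = 7
    num_attention_heads = 5
    num_hidden_layers = 2


cache = StaticCache(_config(), max_batch_size=4, max_cache_len=6, device="cpu", dtype=torch.float32)
for path, leaf in torch.utils._pytree.tree_flatten_with_path(cache)[0]:
    # Prints paths like ['key_cache'][0] together with each tensor's shape.
    print(torch.utils._pytree.keystr(path), tuple(leaf.shape))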
@@ -630,6 +664,8 @@ class CacheExportIntegrationTest(unittest.TestCase):
         """
         Tests that static cache works with `torch.export()`
         """
+        # TODO: Another test should be implemented to follow the same pattern
+        # as the one implemented for DynamicCache.
         if not is_torch_greater_or_equal("2.3"):
             self.skipTest(reason="This test requires torch >= 2.3 to run.")
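A sketch of what the TODO's StaticCache analogue could look like, assuming the same call shape as the existing DynamicCache export test (illustrative model; whether export succeeds end to end depends on the model's export support):

import torch
from transformers import AutoModelForCausalLM, StaticCache

model = AutoModelForCausalLM.from_pretrained("gpt2")  # illustrative choice
cache = StaticCache(model.config, max_batch_size=1, max_cache_len=16)
input_ids = torch.zeros((1, 4), dtype=torch.long)

# Same pattern as the DynamicCache test: pass the cache as a kwarg and export
# non-strictly, letting the registered pytree node flatten the cache's tensors.
exported = torch.export.export(
    model,
    args=(),
    kwargs={"input_ids": input_ids, "past_key_values": cache, "use_cache": True},
    strict=False,
)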