diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 04ccc6f7efc..5cc544f456e 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -9,7 +9,7 @@ from typing import Any, Optional, Union import torch from packaging import version -from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_6 +from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_6, is_torch_greater_or_equal_than_2_7 from .configuration_utils import PretrainedConfig from .utils import is_hqq_available, is_optimum_quanto_available, is_torch_greater_or_equal, logging @@ -651,9 +651,7 @@ class DynamicCache(Cache): # Utilities for `DynamicCache` <> torch.export support -def _flatten_dynamic_cache( - dynamic_cache: DynamicCache, -): +def _flatten_dynamic_cache(dynamic_cache: DynamicCache): """Flattens DynamicCache into flat list of tensors for `torch.export.export` to consume""" if not isinstance(dynamic_cache, DynamicCache): raise RuntimeError("This pytree flattening function should only be applied to DynamicCache") @@ -682,7 +680,7 @@ def _flatten_with_keys_dynamic_cache(dynamic_cache: DynamicCache): def _unflatten_dynamic_cache( values, context: torch.utils._pytree.Context, -): +) -> DynamicCache: dictionary = torch.utils._pytree._dict_unflatten(values, context) cache = DynamicCache() for k, v in dictionary.items(): @@ -1230,6 +1228,69 @@ class StaticCache(Cache): return kv_length, 0 +# Utilities for `StaticCache` <> torch.export support +def _flatten_static_cache(cache: StaticCache): + """Flattens DynamicCache into flat list of tensors for `torch.export.export` to consume""" + if not isinstance(cache, StaticCache): + raise RuntimeError("This pytree flattening function should only be applied to StaticCache") + + if not is_torch_greater_or_equal_than_2_7: + logger.warning_once( + "StaticCache + torch.export is tested on torch 2.7.0+ and may not work on earlier versions." + ) + + dictionary = { + "key_cache": getattr(cache, "key_cache"), + "value_cache": getattr(cache, "value_cache"), + } + return torch.utils._pytree._dict_flatten(dictionary) + + +def _flatten_with_keys_static_cache(cache: StaticCache): + dictionary = { + "key_cache": getattr(cache, "key_cache"), + "value_cache": getattr(cache, "value_cache"), + } + return torch.utils._pytree._dict_flatten_with_keys(dictionary) + + +def _make_static_cache(key_value_pairs): + class _config: + def __init__(self): + self.head_dim = key_value_pairs[0][0].shape[-1] + self.num_attention_heads = key_value_pairs[0][0].shape[1] + self.num_hidden_layers = len(key_value_pairs) + + cache = StaticCache( + _config(), + max_batch_size=key_value_pairs[0][0].shape[0], + device=key_value_pairs[0][0].device, + dtype=key_value_pairs[0][0].dtype, + max_cache_len=key_value_pairs[0][0].shape[2], + ) + for i in range(len(key_value_pairs)): + cache.key_cache[i][:, :, :, :] = key_value_pairs[i][0] + cache.value_cache[i][:, :, :, :] = key_value_pairs[i][1] + return cache + + +def _unflatten_static_cache( + values, + context: torch.utils._pytree.Context, +) -> StaticCache: + return _make_static_cache(list(zip(values[0], values[1]))) + + +if is_torch_greater_or_equal("2.7"): + torch.utils._pytree.register_pytree_node( + StaticCache, + _flatten_static_cache, + _unflatten_static_cache, + serialized_type_name=f"{StaticCache.__module__}.{StaticCache.__name__}", + flatten_with_keys_fn=_flatten_with_keys_static_cache, + ) + + class SlidingWindowCache(StaticCache): """ Sliding Window Cache class to be used with `torch.compile` for models like Mistral that support sliding window attention. diff --git a/src/transformers/pytorch_utils.py b/src/transformers/pytorch_utils.py index e4bd74720a8..2de1dd8c41a 100644 --- a/src/transformers/pytorch_utils.py +++ b/src/transformers/pytorch_utils.py @@ -28,6 +28,7 @@ ALL_LAYERNORM_LAYERS = [nn.LayerNorm] logger = logging.get_logger(__name__) +is_torch_greater_or_equal_than_2_7 = is_torch_greater_or_equal("2.7", accept_dev=True) is_torch_greater_or_equal_than_2_6 = is_torch_greater_or_equal("2.6", accept_dev=True) is_torch_greater_or_equal_than_2_4 = is_torch_greater_or_equal("2.4", accept_dev=True) is_torch_greater_or_equal_than_2_3 = is_torch_greater_or_equal("2.3", accept_dev=True) diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index 8c864f9b64f..2534ef7da99 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -194,6 +194,39 @@ class CacheTest(unittest.TestCase): self.assertTrue(cached_keys.shape == (1, 1, 10, 128)) self.assertTrue(cached_values.shape == (1, 1, 10, 128)) + def test_unflatten_flatten_static_cache(self): + def make_static_cache(key_value_pairs): + class _config: + def __init__(self): + self.head_dim = key_value_pairs[0][0].shape[-1] + self.num_attention_heads = key_value_pairs[0][0].shape[1] + self.num_hidden_layers = len(key_value_pairs) + + cache = StaticCache( + _config(), + max_batch_size=key_value_pairs[0][0].shape[0], + device=key_value_pairs[0][0].device, + dtype=key_value_pairs[0][0].dtype, + max_cache_len=key_value_pairs[0][0].shape[2], + ) + for i in range(len(key_value_pairs)): + cache.key_cache[i][:, :, :, :] = key_value_pairs[i][0] + cache.value_cache[i][:, :, :, :] = key_value_pairs[i][1] + return cache + + cache = make_static_cache( + [ + (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))), + (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))), + (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))), + ] + ) + flat, spec = torch.utils._pytree.tree_flatten(cache) + self.assertIsInstance(flat, list) + self.assertEqual(len(flat), 6) + cache2 = torch.utils._pytree.tree_unflatten(flat, spec) + self.assertTrue(isinstance(cache2, StaticCache)) + def _skip_on_failed_cache_prerequisites(test, cache_implementation): """Function to skip tests on failed cache prerequisites, given a cache implementation""" @@ -726,6 +759,8 @@ class CacheExportIntegrationTest(unittest.TestCase): """ Tests that static cache works with `torch.export()` """ + # TODO: Another test should be implemented to follow the same pattern + # as the one implemented for DynamicCache. if not is_torch_greater_or_equal("2.3"): self.skipTest(reason="This test requires torch >= 2.3 to run.")