commit d13a543f94
Xavier Dupré, 2025-07-02 12:39:37 -07:00 (committed by GitHub)
3 changed files with 102 additions and 5 deletions


@@ -9,7 +9,7 @@ from typing import Any, Optional, Union
import torch
from packaging import version
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_6
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_6, is_torch_greater_or_equal_than_2_7
from .configuration_utils import PretrainedConfig
from .utils import is_hqq_available, is_optimum_quanto_available, is_torch_greater_or_equal, logging
@@ -651,9 +651,7 @@ class DynamicCache(Cache):
# Utilities for `DynamicCache` <> torch.export support
def _flatten_dynamic_cache(
dynamic_cache: DynamicCache,
):
def _flatten_dynamic_cache(dynamic_cache: DynamicCache):
"""Flattens DynamicCache into flat list of tensors for `torch.export.export` to consume"""
if not isinstance(dynamic_cache, DynamicCache):
raise RuntimeError("This pytree flattening function should only be applied to DynamicCache")
@@ -682,7 +680,7 @@ def _flatten_with_keys_dynamic_cache(dynamic_cache: DynamicCache):
def _unflatten_dynamic_cache(
values,
context: torch.utils._pytree.Context,
):
) -> DynamicCache:
dictionary = torch.utils._pytree._dict_unflatten(values, context)
cache = DynamicCache()
for k, v in dictionary.items():
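
For context, a minimal sketch (not part of this diff) of the DynamicCache round trip these helpers support, assuming torch >= 2.3, where DynamicCache is registered as a pytree node:

import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
# populate layer 0 with 4 cached positions (batch=1, heads=2, head_dim=8)
cache.update(torch.rand(1, 2, 4, 8), torch.rand(1, 2, 4, 8), layer_idx=0)
flat, spec = torch.utils._pytree.tree_flatten(cache)  # one key + one value tensor per layer
restored = torch.utils._pytree.tree_unflatten(flat, spec)
assert isinstance(restored, DynamicCache)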
@@ -1230,6 +1228,69 @@ class StaticCache(Cache):
return kv_length, 0
# Utilities for `StaticCache` <> torch.export support
def _flatten_static_cache(cache: StaticCache):
"""Flattens DynamicCache into flat list of tensors for `torch.export.export` to consume"""
if not isinstance(cache, StaticCache):
raise RuntimeError("This pytree flattening function should only be applied to StaticCache")
if not is_torch_greater_or_equal_than_2_7:
logger.warning_once(
"StaticCache + torch.export is tested on torch 2.7.0+ and may not work on earlier versions."
)
    dictionary = {
        "key_cache": cache.key_cache,
        "value_cache": cache.value_cache,
    }
    return torch.utils._pytree._dict_flatten(dictionary)
def _flatten_with_keys_static_cache(cache: StaticCache):
    dictionary = {
        "key_cache": cache.key_cache,
        "value_cache": cache.value_cache,
    }
return torch.utils._pytree._dict_flatten_with_keys(dictionary)
def _make_static_cache(key_value_pairs):
    class _config:
        # Minimal stand-in config exposing only the attributes StaticCache reads.
        def __init__(self):
            self.head_dim = key_value_pairs[0][0].shape[-1]
            self.num_attention_heads = key_value_pairs[0][0].shape[1]
            self.num_hidden_layers = len(key_value_pairs)
cache = StaticCache(
_config(),
max_batch_size=key_value_pairs[0][0].shape[0],
device=key_value_pairs[0][0].device,
dtype=key_value_pairs[0][0].dtype,
max_cache_len=key_value_pairs[0][0].shape[2],
)
    for i, (key, value) in enumerate(key_value_pairs):
        cache.key_cache[i][:, :, :, :] = key
        cache.value_cache[i][:, :, :, :] = value
return cache
def _unflatten_static_cache(
values,
context: torch.utils._pytree.Context,
) -> StaticCache:
return _make_static_cache(list(zip(values[0], values[1])))
if is_torch_greater_or_equal("2.7"):
torch.utils._pytree.register_pytree_node(
StaticCache,
_flatten_static_cache,
_unflatten_static_cache,
serialized_type_name=f"{StaticCache.__module__}.{StaticCache.__name__}",
flatten_with_keys_fn=_flatten_with_keys_static_cache,
)
class SlidingWindowCache(StaticCache):
"""
Sliding Window Cache class to be used with `torch.compile` for models like Mistral that support sliding window attention.
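
For reference (not part of this diff), a minimal sketch of the round trip the registration above enables on torch >= 2.7; _Cfg is a hypothetical stand-in mirroring the _config helper in _make_static_cache:

import torch
from transformers.cache_utils import StaticCache

class _Cfg:
    # hypothetical stub: only the attributes StaticCache reads here
    def __init__(self):
        self.head_dim = 8
        self.num_attention_heads = 2
        self.num_hidden_layers = 3

cache = StaticCache(_Cfg(), max_batch_size=1, max_cache_len=4, device="cpu", dtype=torch.float32)
flat, spec = torch.utils._pytree.tree_flatten(cache)
assert len(flat) == 6  # one key tensor and one value tensor per layer
restored = torch.utils._pytree.tree_unflatten(flat, spec)
assert isinstance(restored, StaticCache)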


@@ -28,6 +28,7 @@ ALL_LAYERNORM_LAYERS = [nn.LayerNorm]
logger = logging.get_logger(__name__)
is_torch_greater_or_equal_than_2_7 = is_torch_greater_or_equal("2.7", accept_dev=True)
is_torch_greater_or_equal_than_2_6 = is_torch_greater_or_equal("2.6", accept_dev=True)
is_torch_greater_or_equal_than_2_4 = is_torch_greater_or_equal("2.4", accept_dev=True)
is_torch_greater_or_equal_than_2_3 = is_torch_greater_or_equal("2.3", accept_dev=True)
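
A note on accept_dev=True (my reading of the helper's behavior: development builds are compared by base version, so pre-releases of the target release pass the check):

from transformers.utils import is_torch_greater_or_equal

# With accept_dev=True, a build like "2.7.0.dev20250601" satisfies ">= 2.7";
# with the default accept_dev=False, dev pre-releases compare below "2.7".
print(is_torch_greater_or_equal("2.7", accept_dev=True))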


@@ -194,6 +194,39 @@ class CacheTest(unittest.TestCase):
self.assertTrue(cached_keys.shape == (1, 1, 10, 128))
self.assertTrue(cached_values.shape == (1, 1, 10, 128))
def test_unflatten_flatten_static_cache(self):
def make_static_cache(key_value_pairs):
            class _config:
                # Minimal stand-in config exposing only the attributes StaticCache reads.
                def __init__(self):
                    self.head_dim = key_value_pairs[0][0].shape[-1]
                    self.num_attention_heads = key_value_pairs[0][0].shape[1]
                    self.num_hidden_layers = len(key_value_pairs)
cache = StaticCache(
_config(),
max_batch_size=key_value_pairs[0][0].shape[0],
device=key_value_pairs[0][0].device,
dtype=key_value_pairs[0][0].dtype,
max_cache_len=key_value_pairs[0][0].shape[2],
)
            for i, (key, value) in enumerate(key_value_pairs):
                cache.key_cache[i][:, :, :, :] = key
                cache.value_cache[i][:, :, :, :] = value
return cache
cache = make_static_cache(
[
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
]
)
flat, spec = torch.utils._pytree.tree_flatten(cache)
self.assertIsInstance(flat, list)
self.assertEqual(len(flat), 6)
cache2 = torch.utils._pytree.tree_unflatten(flat, spec)
        self.assertIsInstance(cache2, StaticCache)
def _skip_on_failed_cache_prerequisites(test, cache_implementation):
"""Function to skip tests on failed cache prerequisites, given a cache implementation"""
@@ -726,6 +759,8 @@ class CacheExportIntegrationTest(unittest.TestCase):
"""
Tests that static cache works with `torch.export()`
"""
# TODO: Another test should be implemented to follow the same pattern
# as the one implemented for DynamicCache.
if not is_torch_greater_or_equal("2.3"):
self.skipTest(reason="This test requires torch >= 2.3 to run.")
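
The TODO above calls for a StaticCache export test in the style of the DynamicCache one. A hedged sketch (not part of this diff) of what such a test might exercise, assuming torch >= 2.7 and the pytree registration added above; _Cfg and TinyModel are hypothetical:

import torch
from transformers.cache_utils import StaticCache

class _Cfg:
    # hypothetical stub, as in the earlier sketch
    def __init__(self):
        self.head_dim = 8
        self.num_attention_heads = 2
        self.num_hidden_layers = 1

class TinyModel(torch.nn.Module):
    # hypothetical toy module that reads the cache, so the exported
    # graph receives the cache tensors as structured inputs
    def forward(self, x, cache):
        return x + cache.key_cache[0].sum()

cache = StaticCache(_Cfg(), max_batch_size=1, max_cache_len=4, device="cpu", dtype=torch.float32)
exported = torch.export.export(TinyModel(), (torch.rand(2), cache), strict=False)
print(exported.graph_signature)  # the cache leaves appear among the flattened inputs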