mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
Merge 09ef0d16ee
into 2d561713f8
This commit is contained in:
commit
d13a543f94
@ -9,7 +9,7 @@ from typing import Any, Optional, Union
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_6
|
||||
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_6, is_torch_greater_or_equal_than_2_7
|
||||
|
||||
from .configuration_utils import PretrainedConfig
|
||||
from .utils import is_hqq_available, is_optimum_quanto_available, is_torch_greater_or_equal, logging
|
||||
@ -651,9 +651,7 @@ class DynamicCache(Cache):
|
||||
|
||||
|
||||
# Utilities for `DynamicCache` <> torch.export support
|
||||
def _flatten_dynamic_cache(
|
||||
dynamic_cache: DynamicCache,
|
||||
):
|
||||
def _flatten_dynamic_cache(dynamic_cache: DynamicCache):
|
||||
"""Flattens DynamicCache into flat list of tensors for `torch.export.export` to consume"""
|
||||
if not isinstance(dynamic_cache, DynamicCache):
|
||||
raise RuntimeError("This pytree flattening function should only be applied to DynamicCache")
|
||||
@ -682,7 +680,7 @@ def _flatten_with_keys_dynamic_cache(dynamic_cache: DynamicCache):
|
||||
def _unflatten_dynamic_cache(
|
||||
values,
|
||||
context: torch.utils._pytree.Context,
|
||||
):
|
||||
) -> DynamicCache:
|
||||
dictionary = torch.utils._pytree._dict_unflatten(values, context)
|
||||
cache = DynamicCache()
|
||||
for k, v in dictionary.items():
|
||||
@ -1230,6 +1228,69 @@ class StaticCache(Cache):
|
||||
return kv_length, 0
|
||||
|
||||
|
||||
# Utilities for `StaticCache` <> torch.export support
|
||||
def _flatten_static_cache(cache: StaticCache):
|
||||
"""Flattens DynamicCache into flat list of tensors for `torch.export.export` to consume"""
|
||||
if not isinstance(cache, StaticCache):
|
||||
raise RuntimeError("This pytree flattening function should only be applied to StaticCache")
|
||||
|
||||
if not is_torch_greater_or_equal_than_2_7:
|
||||
logger.warning_once(
|
||||
"StaticCache + torch.export is tested on torch 2.7.0+ and may not work on earlier versions."
|
||||
)
|
||||
|
||||
dictionary = {
|
||||
"key_cache": getattr(cache, "key_cache"),
|
||||
"value_cache": getattr(cache, "value_cache"),
|
||||
}
|
||||
return torch.utils._pytree._dict_flatten(dictionary)
|
||||
|
||||
|
||||
def _flatten_with_keys_static_cache(cache: StaticCache):
|
||||
dictionary = {
|
||||
"key_cache": getattr(cache, "key_cache"),
|
||||
"value_cache": getattr(cache, "value_cache"),
|
||||
}
|
||||
return torch.utils._pytree._dict_flatten_with_keys(dictionary)
|
||||
|
||||
|
||||
def _make_static_cache(key_value_pairs):
|
||||
class _config:
|
||||
def __init__(self):
|
||||
self.head_dim = key_value_pairs[0][0].shape[-1]
|
||||
self.num_attention_heads = key_value_pairs[0][0].shape[1]
|
||||
self.num_hidden_layers = len(key_value_pairs)
|
||||
|
||||
cache = StaticCache(
|
||||
_config(),
|
||||
max_batch_size=key_value_pairs[0][0].shape[0],
|
||||
device=key_value_pairs[0][0].device,
|
||||
dtype=key_value_pairs[0][0].dtype,
|
||||
max_cache_len=key_value_pairs[0][0].shape[2],
|
||||
)
|
||||
for i in range(len(key_value_pairs)):
|
||||
cache.key_cache[i][:, :, :, :] = key_value_pairs[i][0]
|
||||
cache.value_cache[i][:, :, :, :] = key_value_pairs[i][1]
|
||||
return cache
|
||||
|
||||
|
||||
def _unflatten_static_cache(
|
||||
values,
|
||||
context: torch.utils._pytree.Context,
|
||||
) -> StaticCache:
|
||||
return _make_static_cache(list(zip(values[0], values[1])))
|
||||
|
||||
|
||||
if is_torch_greater_or_equal("2.7"):
|
||||
torch.utils._pytree.register_pytree_node(
|
||||
StaticCache,
|
||||
_flatten_static_cache,
|
||||
_unflatten_static_cache,
|
||||
serialized_type_name=f"{StaticCache.__module__}.{StaticCache.__name__}",
|
||||
flatten_with_keys_fn=_flatten_with_keys_static_cache,
|
||||
)
|
||||
|
||||
|
||||
class SlidingWindowCache(StaticCache):
|
||||
"""
|
||||
Sliding Window Cache class to be used with `torch.compile` for models like Mistral that support sliding window attention.
|
||||
|
@ -28,6 +28,7 @@ ALL_LAYERNORM_LAYERS = [nn.LayerNorm]
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
is_torch_greater_or_equal_than_2_7 = is_torch_greater_or_equal("2.7", accept_dev=True)
|
||||
is_torch_greater_or_equal_than_2_6 = is_torch_greater_or_equal("2.6", accept_dev=True)
|
||||
is_torch_greater_or_equal_than_2_4 = is_torch_greater_or_equal("2.4", accept_dev=True)
|
||||
is_torch_greater_or_equal_than_2_3 = is_torch_greater_or_equal("2.3", accept_dev=True)
|
||||
|
@ -194,6 +194,39 @@ class CacheTest(unittest.TestCase):
|
||||
self.assertTrue(cached_keys.shape == (1, 1, 10, 128))
|
||||
self.assertTrue(cached_values.shape == (1, 1, 10, 128))
|
||||
|
||||
def test_unflatten_flatten_static_cache(self):
|
||||
def make_static_cache(key_value_pairs):
|
||||
class _config:
|
||||
def __init__(self):
|
||||
self.head_dim = key_value_pairs[0][0].shape[-1]
|
||||
self.num_attention_heads = key_value_pairs[0][0].shape[1]
|
||||
self.num_hidden_layers = len(key_value_pairs)
|
||||
|
||||
cache = StaticCache(
|
||||
_config(),
|
||||
max_batch_size=key_value_pairs[0][0].shape[0],
|
||||
device=key_value_pairs[0][0].device,
|
||||
dtype=key_value_pairs[0][0].dtype,
|
||||
max_cache_len=key_value_pairs[0][0].shape[2],
|
||||
)
|
||||
for i in range(len(key_value_pairs)):
|
||||
cache.key_cache[i][:, :, :, :] = key_value_pairs[i][0]
|
||||
cache.value_cache[i][:, :, :, :] = key_value_pairs[i][1]
|
||||
return cache
|
||||
|
||||
cache = make_static_cache(
|
||||
[
|
||||
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
|
||||
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
|
||||
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
|
||||
]
|
||||
)
|
||||
flat, spec = torch.utils._pytree.tree_flatten(cache)
|
||||
self.assertIsInstance(flat, list)
|
||||
self.assertEqual(len(flat), 6)
|
||||
cache2 = torch.utils._pytree.tree_unflatten(flat, spec)
|
||||
self.assertTrue(isinstance(cache2, StaticCache))
|
||||
|
||||
|
||||
def _skip_on_failed_cache_prerequisites(test, cache_implementation):
|
||||
"""Function to skip tests on failed cache prerequisites, given a cache implementation"""
|
||||
@ -726,6 +759,8 @@ class CacheExportIntegrationTest(unittest.TestCase):
|
||||
"""
|
||||
Tests that static cache works with `torch.export()`
|
||||
"""
|
||||
# TODO: Another test should be implemented to follow the same pattern
|
||||
# as the one implemented for DynamicCache.
|
||||
if not is_torch_greater_or_equal("2.3"):
|
||||
self.skipTest(reason="This test requires torch >= 2.3 to run.")
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user