Mirror of https://github.com/huggingface/transformers.git
Merge 09ef0d16ee into 2d561713f8 (commit d13a543f94)
@@ -9,7 +9,7 @@ from typing import Any, Optional, Union
 import torch
 from packaging import version

-from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_6
+from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_6, is_torch_greater_or_equal_than_2_7

 from .configuration_utils import PretrainedConfig
 from .utils import is_hqq_available, is_optimum_quanto_available, is_torch_greater_or_equal, logging
@@ -651,9 +651,7 @@ class DynamicCache(Cache):


 # Utilities for `DynamicCache` <> torch.export support
-def _flatten_dynamic_cache(
-    dynamic_cache: DynamicCache,
-):
+def _flatten_dynamic_cache(dynamic_cache: DynamicCache):
     """Flattens DynamicCache into flat list of tensors for `torch.export.export` to consume"""
     if not isinstance(dynamic_cache, DynamicCache):
         raise RuntimeError("This pytree flattening function should only be applied to DynamicCache")
@@ -682,7 +680,7 @@ def _flatten_with_keys_dynamic_cache(dynamic_cache: DynamicCache):
 def _unflatten_dynamic_cache(
     values,
     context: torch.utils._pytree.Context,
-):
+) -> DynamicCache:
     dictionary = torch.utils._pytree._dict_unflatten(values, context)
     cache = DynamicCache()
     for k, v in dictionary.items():
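For context, the two `DynamicCache` hunks above only tidy the flatten/unflatten signatures; the pytree round trip they support is unchanged. A minimal sketch of that round trip, assuming a torch version recent enough that `DynamicCache` is registered as a pytree node (as this module does); the tensor shapes and the `update` call are illustrative, not part of the diff:

import torch
from transformers.cache_utils import DynamicCache

# Populate a single layer with dummy key/value states (arbitrary shapes).
cache = DynamicCache()
cache.update(torch.rand(1, 2, 3, 4), torch.rand(1, 2, 3, 4), layer_idx=0)

# Flatten to a list of tensor leaves plus a spec, then rebuild the cache from them.
flat, spec = torch.utils._pytree.tree_flatten(cache)
restored = torch.utils._pytree.tree_unflatten(flat, spec)
assert isinstance(restored, DynamicCache)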
@@ -1230,6 +1228,69 @@ class StaticCache(Cache):
         return kv_length, 0


+# Utilities for `StaticCache` <> torch.export support
+def _flatten_static_cache(cache: StaticCache):
+    """Flattens StaticCache into flat list of tensors for `torch.export.export` to consume"""
+    if not isinstance(cache, StaticCache):
+        raise RuntimeError("This pytree flattening function should only be applied to StaticCache")
+
+    if not is_torch_greater_or_equal_than_2_7:
+        logger.warning_once(
+            "StaticCache + torch.export is tested on torch 2.7.0+ and may not work on earlier versions."
+        )
+
+    dictionary = {
+        "key_cache": getattr(cache, "key_cache"),
+        "value_cache": getattr(cache, "value_cache"),
+    }
+    return torch.utils._pytree._dict_flatten(dictionary)
+
+
+def _flatten_with_keys_static_cache(cache: StaticCache):
+    dictionary = {
+        "key_cache": getattr(cache, "key_cache"),
+        "value_cache": getattr(cache, "value_cache"),
+    }
+    return torch.utils._pytree._dict_flatten_with_keys(dictionary)
+
+
+def _make_static_cache(key_value_pairs):
+    class _config:
+        def __init__(self):
+            self.head_dim = key_value_pairs[0][0].shape[-1]
+            self.num_attention_heads = key_value_pairs[0][0].shape[1]
+            self.num_hidden_layers = len(key_value_pairs)
+
+    cache = StaticCache(
+        _config(),
+        max_batch_size=key_value_pairs[0][0].shape[0],
+        device=key_value_pairs[0][0].device,
+        dtype=key_value_pairs[0][0].dtype,
+        max_cache_len=key_value_pairs[0][0].shape[2],
+    )
+    for i in range(len(key_value_pairs)):
+        cache.key_cache[i][:, :, :, :] = key_value_pairs[i][0]
+        cache.value_cache[i][:, :, :, :] = key_value_pairs[i][1]
+    return cache
+
+
+def _unflatten_static_cache(
+    values,
+    context: torch.utils._pytree.Context,
+) -> StaticCache:
+    return _make_static_cache(list(zip(values[0], values[1])))
+
+
+if is_torch_greater_or_equal("2.7"):
+    torch.utils._pytree.register_pytree_node(
+        StaticCache,
+        _flatten_static_cache,
+        _unflatten_static_cache,
+        serialized_type_name=f"{StaticCache.__module__}.{StaticCache.__name__}",
+        flatten_with_keys_fn=_flatten_with_keys_static_cache,
+    )
+
+
 class SlidingWindowCache(StaticCache):
     """
     Sliding Window Cache class to be used with `torch.compile` for models like Mistral that support sliding window attention.
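The registration above is what lets `torch.export` and other pytree consumers see a `StaticCache` as a flat list of key/value tensors. A hedged round-trip sketch mirroring the `_make_static_cache` helper (the stand-in config, shapes, and dtype are illustrative assumptions, not part of the diff; torch >= 2.7 is assumed so the pytree node is registered):

import torch
from transformers.cache_utils import StaticCache

class _Cfg:
    # Minimal stand-in config exposing only the fields the helpers above read.
    head_dim = 7
    num_attention_heads = 5
    num_hidden_layers = 2

cache = StaticCache(_Cfg(), max_batch_size=4, max_cache_len=6, device="cpu", dtype=torch.float32)
cache.key_cache[0].fill_(1.0)  # make the round trip observable

# Flatten into leaves plus a spec, rebuild, and check the contents survived.
flat, spec = torch.utils._pytree.tree_flatten(cache)
restored = torch.utils._pytree.tree_unflatten(flat, spec)
torch.testing.assert_close(restored.key_cache[0], cache.key_cache[0])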
@@ -28,6 +28,7 @@ ALL_LAYERNORM_LAYERS = [nn.LayerNorm]

 logger = logging.get_logger(__name__)

+is_torch_greater_or_equal_than_2_7 = is_torch_greater_or_equal("2.7", accept_dev=True)
 is_torch_greater_or_equal_than_2_6 = is_torch_greater_or_equal("2.6", accept_dev=True)
 is_torch_greater_or_equal_than_2_4 = is_torch_greater_or_equal("2.4", accept_dev=True)
 is_torch_greater_or_equal_than_2_3 = is_torch_greater_or_equal("2.3", accept_dev=True)
@@ -194,6 +194,39 @@ class CacheTest(unittest.TestCase):
         self.assertTrue(cached_keys.shape == (1, 1, 10, 128))
         self.assertTrue(cached_values.shape == (1, 1, 10, 128))

+    def test_unflatten_flatten_static_cache(self):
+        def make_static_cache(key_value_pairs):
+            class _config:
+                def __init__(self):
+                    self.head_dim = key_value_pairs[0][0].shape[-1]
+                    self.num_attention_heads = key_value_pairs[0][0].shape[1]
+                    self.num_hidden_layers = len(key_value_pairs)
+
+            cache = StaticCache(
+                _config(),
+                max_batch_size=key_value_pairs[0][0].shape[0],
+                device=key_value_pairs[0][0].device,
+                dtype=key_value_pairs[0][0].dtype,
+                max_cache_len=key_value_pairs[0][0].shape[2],
+            )
+            for i in range(len(key_value_pairs)):
+                cache.key_cache[i][:, :, :, :] = key_value_pairs[i][0]
+                cache.value_cache[i][:, :, :, :] = key_value_pairs[i][1]
+            return cache
+
+        cache = make_static_cache(
+            [
+                (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+                (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+                (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+            ]
+        )
+        flat, spec = torch.utils._pytree.tree_flatten(cache)
+        self.assertIsInstance(flat, list)
+        self.assertEqual(len(flat), 6)
+        cache2 = torch.utils._pytree.tree_unflatten(flat, spec)
+        self.assertTrue(isinstance(cache2, StaticCache))
+

 def _skip_on_failed_cache_prerequisites(test, cache_implementation):
     """Function to skip tests on failed cache prerequisites, given a cache implementation"""
@@ -726,6 +759,8 @@ class CacheExportIntegrationTest(unittest.TestCase):
         """
         Tests that static cache works with `torch.export()`
         """
+        # TODO: Another test should be implemented to follow the same pattern
+        # as the one implemented for DynamicCache.
         if not is_torch_greater_or_equal("2.3"):
             self.skipTest(reason="This test requires torch >= 2.3 to run.")
