Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-30 17:52:35 +06:00
Support tracable dynamicKVcache (#36311)
* Support tracable dynamicKVcache
* Fix lint
* More fine grained test
* Lint
* Update
* Update
* Fix up
* Apply suggestions from code review
* Update src/transformers/cache_utils.py
* Update tests/utils/test_cache_utils.py
* Apply suggestions from code review
* Update
* Change error message
* Rename
* Apply suggestions from code review
* Apply suggestions from code review
* Apply suggestions from code review

---------

Co-authored-by: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com>
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
parent 63c3116530
commit f39f4960f3
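For orientation: the diff below registers DynamicCache as a pytree node so that torch.export.export can trace models that take a DynamicCache as input and return one as output. Here is a minimal sketch of what that registration enables. It is not part of the commit; it assumes torch >= 2.6 and a transformers build that includes this change, and the tensor shapes are arbitrary values chosen only for illustration.

import torch

from transformers import DynamicCache

# Sketch only: dummy key/value states of shape (batch, heads, seq_len, head_dim).
cache = DynamicCache()
cache.update(torch.zeros(1, 2, 4, 8), torch.zeros(1, 2, 4, 8), layer_idx=0)

# With the registration added in this commit, the cache flattens into plain tensors...
leaves, spec = torch.utils._pytree.tree_flatten(cache)
assert all(isinstance(leaf, torch.Tensor) for leaf in leaves)

# ...and unflattens back into an equivalent DynamicCache, which is what lets
# torch.export treat it like any other nested container of tensors.
rebuilt = torch.utils._pytree.tree_unflatten(leaves, spec)
assert isinstance(rebuilt, DynamicCache)
assert all(torch.equal(a, b) for a, b in zip(cache.key_cache, rebuilt.key_cache))
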
src/transformers/cache_utils.py

@@ -8,6 +8,8 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import torch
from packaging import version

from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_6

from .configuration_utils import PretrainedConfig
from .utils import is_hqq_available, is_optimum_quanto_available, is_torch_greater_or_equal, logging
@@ -535,6 +537,64 @@ class DynamicCache(Cache):
        self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]


def _flatten_dynamic_cache(
    dynamic_cache: DynamicCache,
):
    """Flattens DynamicCache into flat list of tensors for `torch.export.export` to consume"""
    if not isinstance(dynamic_cache, DynamicCache):
        raise RuntimeError("This pytree flattening function should only be applied to DynamicCache")

    if not is_torch_greater_or_equal_than_2_6:
        logger.warning_once(
            "DynamicCache + torch.export is tested on torch 2.6.0+ and may not work on earlier versions."
        )

    # NOTE it seems _seen_tokens is deprecated, so probably doesn't need tracking
    dictionary = {
        "key_cache": getattr(dynamic_cache, "key_cache"),
        "value_cache": getattr(dynamic_cache, "value_cache"),
    }
    return torch.utils._pytree._dict_flatten(dictionary)


def _flatten_with_keys_dynamic_cache(dynamic_cache: DynamicCache):
    dictionary = {
        "key_cache": getattr(dynamic_cache, "key_cache"),
        "value_cache": getattr(dynamic_cache, "value_cache"),
    }
    return torch.utils._pytree._dict_flatten_with_keys(dictionary)


def _unflatten_dynamic_cache(
    values,
    context: torch.utils._pytree.Context,
):
    dictionary = torch.utils._pytree._dict_unflatten(values, context)
    cache = DynamicCache()
    for k, v in dictionary.items():
        setattr(cache, k, v)
    return cache


def _flatten_dynamic_cache_for_fx(cache, spec):
    dictionary = {
        "key_cache": getattr(cache, "key_cache"),
        "value_cache": getattr(cache, "value_cache"),
    }
    return torch.utils._pytree.tree_flatten(dictionary)[0]


torch.utils._pytree.register_pytree_node(
    DynamicCache,
    _flatten_dynamic_cache,
    _unflatten_dynamic_cache,
    serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}",
    flatten_with_keys_fn=_flatten_with_keys_dynamic_cache,
)
# TODO (tmanlaibaatar) This won't be needed in torch 2.7.
torch.fx._pytree.register_pytree_flatten_spec(DynamicCache, _flatten_dynamic_cache_for_fx)


class OffloadedCache(DynamicCache):
    """
    A drop-in replacement for DynamicCache that conserves accelerator(GPU, XPU) memory at the expense of more CPU memory.
src/transformers/pytorch_utils.py

@@ -31,6 +31,7 @@ logger = logging.get_logger(__name__)

parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version)

is_torch_greater_or_equal_than_2_6 = parsed_torch_version_base >= version.parse("2.6")
is_torch_greater_or_equal_than_2_4 = parsed_torch_version_base >= version.parse("2.4")
is_torch_greater_or_equal_than_2_3 = parsed_torch_version_base >= version.parse("2.3")
is_torch_greater_or_equal_than_2_2 = parsed_torch_version_base >= version.parse("2.2")
@@ -46,10 +47,7 @@ _torch_distributed_available = torch.distributed.is_available()

 if is_torch_greater_or_equal("2.5") and _torch_distributed_available:
     from torch.distributed.tensor import Replicate
-    from torch.distributed.tensor.parallel import (
-        ColwiseParallel,
-        RowwiseParallel,
-    )
+    from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel


 def softmax_backward_data(parent, grad_output, output, dim, self):
tests/utils/test_cache_utils.py

@@ -174,6 +174,60 @@ class CacheTest(unittest.TestCase):
        self.assertTrue(cached_keys.shape == (1, 1, 10, 128))
        self.assertTrue(cached_values.shape == (1, 1, 10, 128))

    def test_dynamic_cache_exportability(self):
        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
        model = model.eval()
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
        prompt = "What is the best way to debug python script?"
        inputs = tokenizer(prompt, return_tensors="pt")
        attention_mask = inputs.attention_mask
        input_ids = inputs.input_ids

        past_key_values = DynamicCache()
        ep = torch.export.export(
            model,
            (),
            {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "past_key_values": past_key_values,
                "use_cache": True,
            },
            strict=False,
        )
        res = ep.module()(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
        )
        self.assertTrue(len(res.past_key_values.key_cache) == model.config.num_hidden_layers)
        self.assertEqual(2 * model.config.num_hidden_layers + 1, len(ep.graph_signature.output_specs))
        self.assertEqual(
            3,
            len(
                [
                    x
                    for x in ep.graph_signature.input_specs
                    if x.kind == torch.export.graph_signature.InputKind.USER_INPUT
                ]
            ),
        )

        past_key_values_eager = DynamicCache()
        res_eager = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values_eager,
            use_cache=True,
        )
        self.assertTrue(torch.allclose(res.logits, res_eager.logits))
        for k1, k2 in zip(res.past_key_values.key_cache, res_eager.past_key_values.key_cache):
            self.assertTrue(torch.allclose(k1, k2))

        for v1, v2 in zip(res.past_key_values.value_cache, res_eager.past_key_values.value_cache):
            self.assertTrue(torch.allclose(v1, v2))

    @slow
    @require_read_token
    def test_static_cache_exportability(self):