Support traceable DynamicCache (#36311)

* Support traceable DynamicCache

* Fix lint

* More fine grained test

* Lint

* Update

* Update

* Fix up

* Apply suggestions from code review

* Update src/transformers/cache_utils.py

* Update tests/utils/test_cache_utils.py

* Apply suggestions from code review

* Update

* Change error message

* Rename

* Apply suggestions from code review

* Apply suggestions from code review

* Apply suggestions from code review

---------

Co-authored-by: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com>
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Author: Tugsbayasgalan Manlaibaatar
Date: 2025-03-19 12:52:30 -04:00
Committed by: GitHub
parent 63c3116530
commit f39f4960f3
3 changed files with 116 additions and 4 deletions

src/transformers/cache_utils.py

@@ -8,6 +8,8 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import torch
from packaging import version
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_6
from .configuration_utils import PretrainedConfig
from .utils import is_hqq_available, is_optimum_quanto_available, is_torch_greater_or_equal, logging
@@ -535,6 +537,64 @@ class DynamicCache(Cache):
            self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]


def _flatten_dynamic_cache(
    dynamic_cache: DynamicCache,
):
    """Flattens DynamicCache into flat list of tensors for `torch.export.export` to consume"""
    if not isinstance(dynamic_cache, DynamicCache):
        raise RuntimeError("This pytree flattening function should only be applied to DynamicCache")

    if not is_torch_greater_or_equal_than_2_6:
        logger.warning_once(
            "DynamicCache + torch.export is tested on torch 2.6.0+ and may not work on earlier versions."
        )

    # NOTE it seems _seen_tokens is deprecated, so probably doesn't need tracking
    dictionary = {
        "key_cache": getattr(dynamic_cache, "key_cache"),
        "value_cache": getattr(dynamic_cache, "value_cache"),
    }
    return torch.utils._pytree._dict_flatten(dictionary)


def _flatten_with_keys_dynamic_cache(dynamic_cache: DynamicCache):
    dictionary = {
        "key_cache": getattr(dynamic_cache, "key_cache"),
        "value_cache": getattr(dynamic_cache, "value_cache"),
    }
    return torch.utils._pytree._dict_flatten_with_keys(dictionary)


def _unflatten_dynamic_cache(
    values,
    context: torch.utils._pytree.Context,
):
    dictionary = torch.utils._pytree._dict_unflatten(values, context)
    cache = DynamicCache()
    for k, v in dictionary.items():
        setattr(cache, k, v)
    return cache


def _flatten_dynamic_cache_for_fx(cache, spec):
    dictionary = {
        "key_cache": getattr(cache, "key_cache"),
        "value_cache": getattr(cache, "value_cache"),
    }
    return torch.utils._pytree.tree_flatten(dictionary)[0]


torch.utils._pytree.register_pytree_node(
    DynamicCache,
    _flatten_dynamic_cache,
    _unflatten_dynamic_cache,
    serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}",
    flatten_with_keys_fn=_flatten_with_keys_dynamic_cache,
)
# TODO (tmanlaibaatar) This won't be needed in torch 2.7.
torch.fx._pytree.register_pytree_flatten_spec(DynamicCache, _flatten_dynamic_cache_for_fx)


class OffloadedCache(DynamicCache):
    """
    A drop-in replacement for DynamicCache that conserves accelerator(GPU, XPU) memory at the expense of more CPU memory.
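A minimal sketch (not from this commit) of what the registration above enables: once DynamicCache is registered as a pytree node, it round-trips through torch.utils._pytree, which is how torch.export.export consumes and reconstructs the cache when it is passed as a model input. Tensor shapes are illustrative.

import torch
from transformers.cache_utils import DynamicCache  # importing runs the pytree registration

cache = DynamicCache()
# Populate a single layer with dummy key/value states (batch, heads, seq_len, head_dim).
cache.update(torch.zeros(1, 2, 4, 8), torch.zeros(1, 2, 4, 8), layer_idx=0)

# Flatten to leaf tensors plus a spec, then rebuild an equivalent cache from them.
leaves, spec = torch.utils._pytree.tree_flatten(cache)
rebuilt = torch.utils._pytree.tree_unflatten(leaves, spec)

assert isinstance(rebuilt, DynamicCache)
assert torch.equal(rebuilt.key_cache[0], cache.key_cache[0])
assert torch.equal(rebuilt.value_cache[0], cache.value_cache[0])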

src/transformers/pytorch_utils.py

@@ -31,6 +31,7 @@ logger = logging.get_logger(__name__)
parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version)
is_torch_greater_or_equal_than_2_6 = parsed_torch_version_base >= version.parse("2.6")
is_torch_greater_or_equal_than_2_4 = parsed_torch_version_base >= version.parse("2.4")
is_torch_greater_or_equal_than_2_3 = parsed_torch_version_base >= version.parse("2.3")
is_torch_greater_or_equal_than_2_2 = parsed_torch_version_base >= version.parse("2.2")
@@ -46,10 +47,7 @@ _torch_distributed_available = torch.distributed.is_available()
if is_torch_greater_or_equal("2.5") and _torch_distributed_available:
from torch.distributed.tensor import Replicate
from torch.distributed.tensor.parallel import (
ColwiseParallel,
RowwiseParallel,
)
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
def softmax_backward_data(parent, grad_output, output, dim, self):

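A small sketch (not from this commit) of why the version flags in this file are computed from base_version: dev and locally-built torch versions (e.g. a nightly with a +cu suffix) would otherwise compare as older than the release they lead up to. The version string below is illustrative.

from packaging import version

nightly = "2.6.0.dev20250101+cu124"  # illustrative nightly build string
parsed = version.parse(version.parse(nightly).base_version)

assert version.parse(nightly) < version.parse("2.6")  # dev releases sort before the final 2.6
assert parsed >= version.parse("2.6")                 # base_version ("2.6.0") compares as expected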
tests/utils/test_cache_utils.py

@@ -174,6 +174,60 @@ class CacheTest(unittest.TestCase):
        self.assertTrue(cached_keys.shape == (1, 1, 10, 128))
        self.assertTrue(cached_values.shape == (1, 1, 10, 128))

    def test_dynamic_cache_exportability(self):
        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
        model = model.eval()
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
        prompt = "What is the best way to debug python script?"
        inputs = tokenizer(prompt, return_tensors="pt")
        attention_mask = inputs.attention_mask
        input_ids = inputs.input_ids

        past_key_values = DynamicCache()
        ep = torch.export.export(
            model,
            (),
            {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "past_key_values": past_key_values,
                "use_cache": True,
            },
            strict=False,
        )
        res = ep.module()(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
        )
        self.assertTrue(len(res.past_key_values.key_cache) == model.config.num_hidden_layers)
        self.assertEqual(2 * model.config.num_hidden_layers + 1, len(ep.graph_signature.output_specs))
        self.assertEqual(
            3,
            len(
                [
                    x
                    for x in ep.graph_signature.input_specs
                    if x.kind == torch.export.graph_signature.InputKind.USER_INPUT
                ]
            ),
        )

        past_key_values_eager = DynamicCache()
        res_eager = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values_eager,
            use_cache=True,
        )
        self.assertTrue(torch.allclose(res.logits, res_eager.logits))
        for k1, k2 in zip(res.past_key_values.key_cache, res_eager.past_key_values.key_cache):
            self.assertTrue(torch.allclose(k1, k2))
        for v1, v2 in zip(res.past_key_values.value_cache, res_eager.past_key_values.value_cache):
            self.assertTrue(torch.allclose(v1, v2))
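A hedged sketch (not part of the test) of what the serialized_type_name argument passed to register_pytree_node is for: it names the DynamicCache type inside the serialized treespec, so an exported program such as ep above can in principle be saved and reloaded with torch.export.save / torch.export.load, provided transformers is imported on the load side so the registration runs again. The file name is illustrative, and ep, input_ids, and attention_mask are the objects built in the test above.

import torch
from transformers.cache_utils import DynamicCache  # importing re-runs the pytree registration

torch.export.save(ep, "tiny_mistral_with_cache.pt2")   # illustrative path
loaded = torch.export.load("tiny_mistral_with_cache.pt2")
out = loaded.module()(
    input_ids=input_ids,
    attention_mask=attention_mask,
    past_key_values=DynamicCache(),
    use_cache=True,
)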
    @slow
    @require_read_token
    def test_static_cache_exportability(self):