From f39f4960f30e3eadd6d948e4dcb2da32eda253b5 Mon Sep 17 00:00:00 2001
From: Tugsbayasgalan Manlaibaatar
Date: Wed, 19 Mar 2025 12:52:30 -0400
Subject: [PATCH] Support tracable dynamicKVcache (#36311)

* Support tracable dynamicKVcache
* Fix lint
* More fine grained test
* Lint
* Update
* Update
* Fix up
* Apply suggestions from code review
* Update src/transformers/cache_utils.py
* Update tests/utils/test_cache_utils.py
* Apply suggestions from code review
* Update
* Change error message
* Rename
* Apply suggestions from code review
* Apply suggestions from code review
* Apply suggestions from code review

---------

Co-authored-by: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com>
Co-authored-by: Joao Gante
---
 src/transformers/cache_utils.py   | 60 +++++++++++++++++++++++++++++++
 src/transformers/pytorch_utils.py |  6 ++--
 tests/utils/test_cache_utils.py   | 54 ++++++++++++++++++++++++++++
 3 files changed, 116 insertions(+), 4 deletions(-)

diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index 558bcfb2e28..d873a17d369 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -8,6 +8,8 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 import torch
 from packaging import version
 
+from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_6
+
 from .configuration_utils import PretrainedConfig
 from .utils import is_hqq_available, is_optimum_quanto_available, is_torch_greater_or_equal, logging
 
@@ -535,6 +537,64 @@ class DynamicCache(Cache):
         self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]
 
 
+def _flatten_dynamic_cache(
+    dynamic_cache: DynamicCache,
+):
+    """Flattens DynamicCache into flat list of tensors for `torch.export.export` to consume"""
+    if not isinstance(dynamic_cache, DynamicCache):
+        raise RuntimeError("This pytree flattening function should only be applied to DynamicCache")
+
+    if not is_torch_greater_or_equal_than_2_6:
+        logger.warning_once(
+            "DynamicCache + torch.export is tested on torch 2.6.0+ and may not work on earlier versions."
+        )
+
+    # NOTE it seems _seen_tokens is deprecated, so probably doesn't need tracking
+    dictionary = {
+        "key_cache": getattr(dynamic_cache, "key_cache"),
+        "value_cache": getattr(dynamic_cache, "value_cache"),
+    }
+    return torch.utils._pytree._dict_flatten(dictionary)
+
+
+def _flatten_with_keys_dynamic_cache(dynamic_cache: DynamicCache):
+    dictionary = {
+        "key_cache": getattr(dynamic_cache, "key_cache"),
+        "value_cache": getattr(dynamic_cache, "value_cache"),
+    }
+    return torch.utils._pytree._dict_flatten_with_keys(dictionary)
+
+
+def _unflatten_dynamic_cache(
+    values,
+    context: torch.utils._pytree.Context,
+):
+    dictionary = torch.utils._pytree._dict_unflatten(values, context)
+    cache = DynamicCache()
+    for k, v in dictionary.items():
+        setattr(cache, k, v)
+    return cache
+
+
+def _flatten_dynamic_cache_for_fx(cache, spec):
+    dictionary = {
+        "key_cache": getattr(cache, "key_cache"),
+        "value_cache": getattr(cache, "value_cache"),
+    }
+    return torch.utils._pytree.tree_flatten(dictionary)[0]
+
+
+torch.utils._pytree.register_pytree_node(
+    DynamicCache,
+    _flatten_dynamic_cache,
+    _unflatten_dynamic_cache,
+    serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}",
+    flatten_with_keys_fn=_flatten_with_keys_dynamic_cache,
+)
+# TODO (tmanlaibaatar) This won't be needed in torch 2.7.
+torch.fx._pytree.register_pytree_flatten_spec(DynamicCache, _flatten_dynamic_cache_for_fx)
+
+
 class OffloadedCache(DynamicCache):
     """
     A drop-in replacement for DynamicCache that conserves accelerator(GPU, XPU) memory at the expense of more CPU memory.
diff --git a/src/transformers/pytorch_utils.py b/src/transformers/pytorch_utils.py
index d81c74be95c..d16d07f597f 100644
--- a/src/transformers/pytorch_utils.py
+++ b/src/transformers/pytorch_utils.py
@@ -31,6 +31,7 @@ logger = logging.get_logger(__name__)
 
 parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version)
 
+is_torch_greater_or_equal_than_2_6 = parsed_torch_version_base >= version.parse("2.6")
 is_torch_greater_or_equal_than_2_4 = parsed_torch_version_base >= version.parse("2.4")
 is_torch_greater_or_equal_than_2_3 = parsed_torch_version_base >= version.parse("2.3")
 is_torch_greater_or_equal_than_2_2 = parsed_torch_version_base >= version.parse("2.2")
@@ -46,10 +47,7 @@ _torch_distributed_available = torch.distributed.is_available()
 
 if is_torch_greater_or_equal("2.5") and _torch_distributed_available:
     from torch.distributed.tensor import Replicate
-    from torch.distributed.tensor.parallel import (
-        ColwiseParallel,
-        RowwiseParallel,
-    )
+    from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
 
 
 def softmax_backward_data(parent, grad_output, output, dim, self):
diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py
index efe4e6af5c1..fd5b720d900 100644
--- a/tests/utils/test_cache_utils.py
+++ b/tests/utils/test_cache_utils.py
@@ -174,6 +174,60 @@ class CacheTest(unittest.TestCase):
         self.assertTrue(cached_keys.shape == (1, 1, 10, 128))
         self.assertTrue(cached_values.shape == (1, 1, 10, 128))
 
+    def test_dynamic_cache_exportability(self):
+        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
+        model = model.eval()
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
+        prompt = "What is the best way to debug python script?"
+        inputs = tokenizer(prompt, return_tensors="pt")
+        attention_mask = inputs.attention_mask
+        input_ids = inputs.input_ids
+
+        past_key_values = DynamicCache()
+        ep = torch.export.export(
+            model,
+            (),
+            {
+                "input_ids": input_ids,
+                "attention_mask": attention_mask,
+                "past_key_values": past_key_values,
+                "use_cache": True,
+            },
+            strict=False,
+        )
+        res = ep.module()(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            use_cache=True,
+        )
+        self.assertTrue(len(res.past_key_values.key_cache) == model.config.num_hidden_layers)
+        self.assertEqual(2 * model.config.num_hidden_layers + 1, len(ep.graph_signature.output_specs))
+        self.assertEqual(
+            3,
+            len(
+                [
+                    x
+                    for x in ep.graph_signature.input_specs
+                    if x.kind == torch.export.graph_signature.InputKind.USER_INPUT
+                ]
+            ),
+        )
+
+        past_key_values_eager = DynamicCache()
+        res_eager = model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values_eager,
+            use_cache=True,
+        )
+        self.assertTrue(torch.allclose(res.logits, res_eager.logits))
+        for k1, k2 in zip(res.past_key_values.key_cache, res_eager.past_key_values.key_cache):
+            self.assertTrue(torch.allclose(k1, k2))
+
+        for v1, v2 in zip(res.past_key_values.value_cache, res_eager.past_key_values.value_cache):
+            self.assertTrue(torch.allclose(v1, v2))
+
     @slow
     @require_read_token
     def test_static_cache_exportability(self):
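
A minimal usage sketch (not part of the patch above): with this pytree registration in place and torch 2.6+, a DynamicCache can round-trip through torch.utils._pytree, which is what lets torch.export.export treat the cache as a container of plain tensors. The tensor shapes below are illustrative toy values, not anything prescribed by the PR.

    import torch
    from transformers import DynamicCache

    cache = DynamicCache()
    # Populate one layer with toy (batch, num_heads, seq_len, head_dim) tensors.
    cache.update(torch.randn(1, 2, 4, 8), torch.randn(1, 2, 4, 8), layer_idx=0)

    # Flatten the cache into leaf tensors plus a spec, then rebuild it.
    leaves, spec = torch.utils._pytree.tree_flatten(cache)
    rebuilt = torch.utils._pytree.tree_unflatten(leaves, spec)

    assert isinstance(rebuilt, DynamicCache)
    assert torch.equal(rebuilt.key_cache[0], cache.key_cache[0])
    assert torch.equal(rebuilt.value_cache[0], cache.value_cache[0])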