transformers/tests/utils/test_cache_utils.py

# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import unittest

from parameterized import parameterized

from transformers import set_seed
from transformers.generation.configuration_utils import ALL_CACHE_IMPLEMENTATIONS
from transformers.testing_utils import (
    CaptureStderr,
    cleanup,
    get_gpu_count,
    is_torch_available,
    require_torch,
    require_torch_accelerator,
    require_torch_gpu,
    require_torch_multi_gpu,
    slow,
    torch_device,
)
from transformers.utils import is_optimum_quanto_available, is_torch_greater_or_equal


if is_torch_available():
    import torch

    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        Cache,
        ClvpForCausalLM,
        DynamicCache,
        GenerationConfig,
        LlamaConfig,
        StaticCache,
        convert_and_export_with_cache,
    )


TEST_CACHE_IMPLEMENTATIONS = [
    cache_name
    for cache_name in ALL_CACHE_IMPLEMENTATIONS
    # TODO (joao): Mamba is not compatible with most models, remove from `ALL_CACHE_IMPLEMENTATIONS`?
    if cache_name != "mamba"
    # TODO (joao): offloaded_hybrid == offloaded_hybrid_chunked, deprecate one of them
    if cache_name != "offloaded_hybrid"
]
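# These names are fed to `model.generate(..., cache_implementation=...)` in the parameterized tests below.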


@require_torch
class CacheTest(unittest.TestCase):
    """Cache tests that don't require loading models"""

    def test_dynamic_cache_retrocompatibility(self):
        """Tests that we can convert back and forth between the legacy cache format and DynamicCache"""
        legacy_cache = ()
        new_cache = DynamicCache()

        # Creates a new cache with 10 layers in both formats
        for layer_idx in range(10):
            new_key = torch.rand((2, 4, 8, 16))
            new_value = torch.rand((2, 4, 8, 16))
            new_cache.update(new_key, new_value, layer_idx)
            legacy_cache += ((new_key, new_value),)
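
        # For reference: the legacy format is a tuple with one (key, value) pair per layer, each tensor of shape
        # (batch_size, num_heads, seq_len, head_dim). `DynamicCache` holds the same tensors and is indexable in the
        # same way, plus it provides the `from_legacy_cache` / `to_legacy_cache` conversion helpers used below.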

        # Sanity check 1: they must have the same shapes
        self.assertEqual(len(legacy_cache), len(new_cache))
        for layer_idx in range(10):
            self.assertEqual(len(legacy_cache[layer_idx]), len(new_cache[layer_idx]))
            for key_value_idx in range(2):
                self.assertTrue(
                    legacy_cache[layer_idx][key_value_idx].shape == new_cache[layer_idx][key_value_idx].shape
                )

        # Sanity check 2: we can get the sequence length in multiple ways with DynamicCache, and they return the
        # expected value
        self.assertTrue(legacy_cache[0][0].shape[-2] == new_cache[0][0].shape[-2] == new_cache.get_seq_length() == 8)

        # Sanity check 3: they must be equal, and both support indexing
        for layer_idx in range(10):
            for key_value_idx in range(2):
                self.assertTrue(
                    torch.allclose(new_cache[layer_idx][key_value_idx], legacy_cache[layer_idx][key_value_idx])
                )

        # Test 1: We can convert from legacy to new with no changes
        from_legacy = DynamicCache.from_legacy_cache(legacy_cache)
        for layer_idx in range(10):
            for key_value_idx in range(2):
                self.assertTrue(
                    torch.allclose(from_legacy[layer_idx][key_value_idx], legacy_cache[layer_idx][key_value_idx])
                )

        # Test 2: We can convert from new to legacy with no changes
        to_legacy = new_cache.to_legacy_cache()
        for layer_idx in range(10):
            for key_value_idx in range(2):
                self.assertTrue(
                    torch.allclose(to_legacy[layer_idx][key_value_idx], new_cache[layer_idx][key_value_idx])
                )

    def test_reorder_cache_retrocompatibility(self):
        """Tests that Cache.reorder_cache is retrocompatible with the legacy code path"""
        legacy_reorder_fn = ClvpForCausalLM._reorder_cache  # An example of a legacy `_reorder_cache` function
        legacy_cache = ()
        new_cache = DynamicCache()

        # Creates a new cache with 10 layers in both formats
        for layer_idx in range(10):
            new_key = torch.rand((4, 4, 8, 16))
            new_value = torch.rand((4, 4, 8, 16))
            new_cache.update(new_key, new_value, layer_idx)
            legacy_cache += ((new_key, new_value),)

        # Let's create some dummy beam indices. From the shape above, it is equivalent to the case where num_beams=4
        # and batch_size=1
        beam_idx = torch.randint(low=0, high=4, size=(4,))
        legacy_cache_reordered = legacy_reorder_fn(legacy_cache, beam_idx)
        new_cache.reorder_cache(beam_idx)

        # Let's check that the results are the same
        for layer_idx in range(10):
            for key_value_idx in range(2):
                self.assertTrue(
                    torch.allclose(
                        new_cache[layer_idx][key_value_idx], legacy_cache_reordered[layer_idx][key_value_idx]
                    )
                )

    def test_static_cache_mha_mqa_gqa(self):
        """
        Tests that static cache works with multi-head attention (MHA), grouped query attention (GQA), and multi-query
        attention (MQA)
        """

        def _random_kvs(config):
            # shape for key and values: (batch_size, num_heads, seq_len, head_dim)
            random_keys = torch.rand(
                (1, config.num_key_value_heads, 1, config.hidden_size // config.num_attention_heads),
                device=torch_device,
            )
            random_values = torch.rand(
                (1, config.num_key_value_heads, 1, config.hidden_size // config.num_attention_heads),
                device=torch_device,
            )
            return random_keys, random_values

        mha_config = LlamaConfig(num_attention_heads=32)
        mha_static_cache = StaticCache(config=mha_config, max_batch_size=1, max_cache_len=10, device=torch_device)
        cached_keys, cached_values = mha_static_cache.update(
            *_random_kvs(mha_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)}
        )
        self.assertTrue(cached_keys.shape == (1, 32, 10, 128))
        self.assertTrue(cached_values.shape == (1, 32, 10, 128))

        gqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=4)
        gqa_static_cache = StaticCache(config=gqa_config, max_batch_size=1, max_cache_len=10, device=torch_device)
        cached_keys, cached_values = gqa_static_cache.update(
            *_random_kvs(gqa_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)}
        )
        self.assertTrue(cached_keys.shape == (1, 4, 10, 128))
        self.assertTrue(cached_values.shape == (1, 4, 10, 128))

        mqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=1)
        mqa_static_cache = StaticCache(config=mqa_config, max_batch_size=1, max_cache_len=10, device=torch_device)
        cached_keys, cached_values = mqa_static_cache.update(
            *_random_kvs(mqa_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)}
        )
        self.assertTrue(cached_keys.shape == (1, 1, 10, 128))
        self.assertTrue(cached_values.shape == (1, 1, 10, 128))


class CacheIntegrationTest(unittest.TestCase):
    """Fast cache integration tests that share the same small model"""

    @classmethod
    def setUpClass(cls):
        # Load once and reuse across tests
        cls.tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M-Instruct", padding_side="left")
        cls.model = AutoModelForCausalLM.from_pretrained(
            "HuggingFaceTB/SmolLM2-135M-Instruct", device_map="auto", torch_dtype=torch.float16
        )
        cls.model.config.sliding_window = 256  # hack to enable the use of caches with sliding windows

    def _skip_on_uninstalled_cache_dependencies(self, cache_implementation):
        """Skips the test if the given cache implementation has missing dependencies or requirements"""
        if cache_implementation == "quantized" and not is_optimum_quanto_available():
            self.skipTest("Quanto is not available")
if "offloaded" in cache_implementation:
has_accelerator = torch_device is not None and torch_device != "cpu"
if not has_accelerator:
self.skipTest("Offloaded caches require an accelerator")

    @parameterized.expand(TEST_CACHE_IMPLEMENTATIONS)
    def test_cache_batched(self, cache_implementation):
        """Sanity check: caches' `.update` function expects batched inputs"""
        self._skip_on_uninstalled_cache_dependencies(cache_implementation)

        EXPECTED_GENERATION = ["A sequence: 1, 2, 3, 4, 5, 6, 7, 8,", "A sequence: A, B, C, D, E, F, G, H"]

        inputs = self.tokenizer(
            ["A sequence: 1, 2, 3, 4, 5", "A sequence: A, B, C"], padding=True, return_tensors="pt"
        )
        inputs = inputs.to(self.model.device)
        gen_out = self.model.generate(
            **inputs,
            do_sample=False,
            max_new_tokens=10,
            return_dict_in_generate=True,
            cache_implementation=cache_implementation,
            disable_compile=True,
        )

        # Sanity check: a cache was used
        self.assertIsInstance(gen_out.past_key_values, Cache)

        # Confirm that the output matches expectations
        decoded = self.tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True)
        self.assertListEqual(decoded, EXPECTED_GENERATION)

    @parameterized.expand(TEST_CACHE_IMPLEMENTATIONS)
    def test_cache_beam_search(self, cache_implementation):
        """
        Sanity check: caches' `reorder_cache` is operational. We can confirm this by looking at the beam indices
        (an output sequence contains multiple beam indices).
        """
        self._skip_on_uninstalled_cache_dependencies(cache_implementation)
        if cache_implementation == "offloaded_hybrid_chunked":
            # TODO (joao, cyril): something is off with `offloaded_hybrid_chunked` aka `OffloadedHybridCache`: the
            # output sequence (and the corresponding beam scores, if we add `output_scores=True`) are significantly
            # different from the other caches.
            self.skipTest("`offloaded_hybrid_chunked` fails this test")

        EXPECTED_GENERATION = [
            "Blue is the color of the sky, and the color of",
            "Blue is the color of the sky, and the second is",
        ]

        inputs = self.tokenizer(["Blue is"], return_tensors="pt").to(self.model.device)
        gen_out = self.model.generate(
            **inputs,
            do_sample=False,
            max_new_tokens=10,
            num_beams=2,
            num_return_sequences=2,
            cache_implementation=cache_implementation,
            disable_compile=True,
            return_dict_in_generate=True,
        )

        # Sanity check: a cache was used
        self.assertIsInstance(gen_out.past_key_values, Cache)

        # At least one of the sequences requires multiple beam indices -> `reorder_cache` had to shift things around
        self.assertTrue(any(len(set(beams_in_sequence)) > 1 for beams_in_sequence in gen_out.beam_indices))

        # Confirm that the output matches expectations
        decoded = self.tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True)
        self.assertListEqual(decoded, EXPECTED_GENERATION)

    @parameterized.expand(TEST_CACHE_IMPLEMENTATIONS)
    def test_cache_extra_left_padding(self, cache_implementation):
        """Tests that adding extra left-padding does not affect the generation with the cache"""
        self._skip_on_uninstalled_cache_dependencies(cache_implementation)

        EXPECTED_GENERATION = ["The cat's whiskers are also a sign of anxiety."]

        inputs = self.tokenizer(["The cat"], padding=True, return_tensors="pt").to(self.model.device)
        generation_kwargs = {
            "do_sample": False,
            "max_new_tokens": 10,
            "cache_implementation": cache_implementation,
            "disable_compile": True,
        }
        gen_out = self.model.generate(**inputs, **generation_kwargs)
        decoded = self.tokenizer.batch_decode(gen_out, skip_special_tokens=True)
        self.assertListEqual(decoded, EXPECTED_GENERATION)

        # Now with extra left-padding
        inputs_expanded = self.tokenizer(["The cat"], padding=True, return_tensors="pt", pad_to_multiple_of=32)
        inputs_expanded = inputs_expanded.to(self.model.device)
        self.assertTrue(inputs.input_ids.shape[1] < inputs_expanded.input_ids.shape[1])
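        # The extra pad tokens are excluded via the attention mask, so the generated continuation must not change.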
        gen_out = self.model.generate(**inputs_expanded, **generation_kwargs)
        decoded = self.tokenizer.batch_decode(gen_out, skip_special_tokens=True)
        self.assertListEqual(decoded, EXPECTED_GENERATION)


@require_torch_accelerator
class CacheHardIntegrationTest(unittest.TestCase):
    """Hard cache integration tests that require loading different models"""

    def tearDown(self):
        # Some tests use large models, which might result in suboptimal torch re-allocation if we run multiple tests
        # in a row
        cleanup(torch_device, gc_collect=True)

    @slow
    def test_dynamic_cache_hard(self):
        """Hard test for base cache implementation -- minor numerical fluctuations will cause this test to fail"""
        tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", padding_side="left")
        model = AutoModelForCausalLM.from_pretrained(
            "meta-llama/Llama-2-7b-hf", device_map="auto", torch_dtype=torch.float16
        )
        inputs = tokenizer(["Here's everything I know about cats. Cats"], return_tensors="pt").to(model.device)

        set_seed(0)
        gen_out = model.generate(**inputs, do_sample=True, max_new_tokens=256)
        decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
        expected_text = (
            "Here's everything I know about cats. Cats are mysterious creatures. They can't talk, and they don't like "
            "to be held. They don't play fetch, and they don't like to be hugged. But they do like to be petted.\n"
            "Cats are also very independent. They don't like to be told what to do, and they don't like to be told "
            "what to eat. They are also very territorial. They don't like to share their food or their toys.\nCats "
            "are also very curious. They like to explore, and they like to play. They are also very fast. They can "
            "run very fast, and they can jump very high.\nCats are also very smart. They can learn tricks, and they "
            "can solve problems. They are also very playful. They like to play with toys, and they like to play with "
            "other cats.\nCats are also very affectionate. They like to be petted, and they like to be held. They "
            "also like to be scratched.\nCats are also very clean. They like to groom themselves, and they like to "
            "clean their litter box.\nCats are also very independent. They don't"
        )
        self.assertEqual(decoded[0], expected_text)
@parameterized.expand([("eager"), ("sdpa")])
@require_torch_gpu
@slow
def test_static_cache_greedy_decoding_pad_left(self, attn_implementation):
"""Tests that different cache implementations work well with eager and SDPA inference"""
EXPECTED_GENERATION = [
"The best color is the one that complements the skin tone of the",
"We should not undermind the issues at hand.\nWe should not undermind the issues",
]
tokenizer = AutoTokenizer.from_pretrained(
"NousResearch/Llama-2-7b-chat-hf", padding_side="left", pad_token="<s>"
)
model = AutoModelForCausalLM.from_pretrained(
"NousResearch/Llama-2-7b-chat-hf",
torch_dtype=torch.bfloat16,
attn_implementation=attn_implementation,
).to(torch_device)
inputs = tokenizer(
["The best color is", "We should not undermind the issues at hand"], padding=True, return_tensors="pt"
).to(model.device)
set_seed(0)
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
with self.subTest(f"{attn_implementation}, dynamic"):
self.assertListEqual(decoded, EXPECTED_GENERATION)
set_seed(0)
gen_out = model.generate(
**inputs, do_sample=False, max_new_tokens=10, cache_implementation="static", disable_compile=True
)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
with self.subTest(f"{attn_implementation}, static, eager"):
self.assertListEqual(decoded, EXPECTED_GENERATION)
set_seed(0)
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10, cache_implementation="static")
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
with self.subTest(f"{attn_implementation}, static, compiled"):
self.assertListEqual(decoded, EXPECTED_GENERATION)

    @require_torch_accelerator
    @slow
    def test_offloaded_cache_uses_less_memory_than_dynamic_cache(self):
        """Tests that OffloadedCache uses less memory than the default DynamicCache"""
        model_name = "microsoft/Phi-3-mini-4k-instruct"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
        device = model.device
        if not is_torch_greater_or_equal("2.7", accept_dev=True) and device.type == "xpu":
            self.skipTest(reason="This test requires torch >= 2.7 to run on xpu.")

        input_text = "Fun fact:"
        inputs = tokenizer(input_text, return_tensors="pt").to(device)
        common = {
            "num_beams": 4,
            "num_beam_groups": 2,
            "num_return_sequences": 4,
            "diversity_penalty": 1.0,
            "max_new_tokens": 20,
            "early_stopping": True,
        }
        original = GenerationConfig(**common)
        offloaded = GenerationConfig(cache_implementation="offloaded", **common)

        torch_accelerator_module = None
        if device.type == "cuda":
            torch_accelerator_module = torch.cuda
        elif device.type == "xpu":
            torch_accelerator_module = torch.xpu

        torch_accelerator_module.reset_peak_memory_stats(device)
        model.generate(generation_config=original, **inputs)
        original_peak_memory = torch_accelerator_module.max_memory_allocated(device)

        torch_accelerator_module.reset_peak_memory_stats(device)
        model.generate(generation_config=offloaded, **inputs)
        offloaded_peak_memory = torch_accelerator_module.max_memory_allocated(device)

        self.assertTrue(offloaded_peak_memory < original_peak_memory)

    @require_torch_gpu
    @slow
    def test_cache_copy(self):
        """Tests that we can manually set a cache, copy, and reuse it for generation"""
        # TODO (joao): test for all cache implementations in `CacheIntegrationTest` after standardizing the
        # lazy init of cache layers
        model_name = "microsoft/Phi-3-mini-4k-instruct"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cuda", torch_dtype=torch.bfloat16)

        prompt_cache = StaticCache(
            config=model.config, max_batch_size=1, max_cache_len=1024, device="cuda", dtype=torch.bfloat16
        )

        INITIAL_PROMPT = "You are a helpful assistant. "
        inputs_initial_prompt = tokenizer(INITIAL_PROMPT, return_tensors="pt").to("cuda")
        # Cache the common prompt; the forward pass must run without grad so the cache can be copied afterwards
        with torch.no_grad():
            prompt_cache = model(**inputs_initial_prompt, past_key_values=prompt_cache).past_key_values
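
        # `prompt_cache` now holds the key/value states of the shared prefix. Each prompt below gets its own deep
        # copy, so generating for one prompt does not write into the cache reused by the next one.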
prompts = ["Help me to write a blogpost about travelling.", "What is the capital of France?"]
responses = []
for prompt in prompts:
new_inputs = tokenizer(INITIAL_PROMPT + prompt, return_tensors="pt").to("cuda")
past_key_values = copy.deepcopy(prompt_cache)
outputs = model.generate(
**new_inputs, past_key_values=past_key_values, max_new_tokens=40, disable_compile=True
)
response = tokenizer.batch_decode(outputs)[0]
responses.append(response)
EXPECTED_DECODED_TEXT = [
"You are a helpful assistant. Help me to write a blogpost about travelling.\n\nTraveling is a wonderful "
"way to explore new places, cultures, and experiences. Whether you are a seasoned traveler or a "
"first-time adventurer, there is always something",
"You are a helpful assistant. What is the capital of France?\n\n\n## Response:Paris is the capital "
"of France.\n\n\n\n\n\n\n<|endoftext|>",
]
self.assertEqual(responses, EXPECTED_DECODED_TEXT)

    @require_torch_multi_gpu
    def test_data_parallel_dynamic_cache(self):
        """
        Tests that the dynamic cache works with nn.DataParallel. Under the hood, the `DynamicCache` is rebuilt from
        the per-replica `DynamicCache` instances in the gather step.
        """
        model_repo = "hf-internal-testing/tiny-random-MistralForCausalLM"
        model = AutoModelForCausalLM.from_pretrained(model_repo).to(torch_device)
        tokenizer = AutoTokenizer.from_pretrained(model_repo)

        # w/o DP: batch_size = num_gpus
        # w/ DP: batch_size = 1 (with num_gpus replicas)
        num_gpus = get_gpu_count()
        model_inputs = tokenizer(["foo bar"] * num_gpus, return_tensors="pt").to(model.device)

        # w/o DP
        no_parallelism_cache = model(**model_inputs).past_key_values
        self.assertIsInstance(no_parallelism_cache, DynamicCache)

        # w/ DP
        model = torch.nn.DataParallel(model)
        parallelism_cache = model(**model_inputs).past_key_values
        self.assertIsInstance(parallelism_cache, DynamicCache)

        # Check that the caches are the same
        for layer_idx in range(len(no_parallelism_cache)):
            for kv_idx in range(2):  # 0 = key, 1 = value
                torch.testing.assert_close(
                    actual=parallelism_cache[layer_idx][kv_idx], expected=no_parallelism_cache[layer_idx][kv_idx]
                )

    @require_torch_gpu
    def test_static_cache_no_cuda_graph_skips(self):
        """
        Tests that generating with a static cache and compilation doesn't skip cuda graphs. Regression test for
        #36543.
        (Note: we set `fullgraph=True`, which according to the torch docs should raise an exception. Instead, the
        messages are written to stderr.)
        """
        model_repo = "hf-internal-testing/tiny-random-MistralForCausalLM"
        model = AutoModelForCausalLM.from_pretrained(model_repo).to(torch_device)
        tokenizer = AutoTokenizer.from_pretrained(model_repo)
        inputs = tokenizer(["foo bar"], return_tensors="pt").to(torch_device)

        # On `main`, prior to #36543, this would send stderr messages about cuda graphs being skipped.
        with CaptureStderr() as cap:
            model.generate(**inputs, max_new_tokens=2, cache_implementation="static")
        self.assertNotIn("cuda", cap.err.lower())

    @require_torch_multi_gpu
    @slow
    def test_static_cache_multi_gpu(self):
        """Regression test for #35164: static cache with multi-gpu"""
        model_id = "google/gemma-2-2b-it"
        tokenizer = AutoTokenizer.from_pretrained(model_id)

        device_map = {"model.embed_tokens": 0, "model.norm": 1, "model.rotary_emb": 1, "lm_head": 0}
        num_hidden_layers = 26
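        # Split the 26 decoder layers evenly across the two GPUs: layers 0-12 on device 0, layers 13-25 on device 1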
        for i in range(num_hidden_layers):
            device_map[f"model.layers.{i}"] = 0 if i < 13 else 1

        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype="bfloat16",
            device_map=device_map,
        )

        inputs = tokenizer("Today is a beautiful day!", return_tensors="pt").to(0)
        _ = model(**inputs)
        _ = model.generate(**inputs, max_new_tokens=2, cache_implementation="hybrid")


@require_torch
class CacheExportIntegrationTest(unittest.TestCase):
    """Cache tests that rely on `torch.export()` and model loading"""

    def test_dynamic_cache_exportability(self):
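        """Tests that a model called with a `DynamicCache` can be exported with `torch.export()` and matches eager"""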
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
model = model.eval()
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
prompt = "What is the best way to debug python script?"
inputs = tokenizer(prompt, return_tensors="pt")
attention_mask = inputs.attention_mask
input_ids = inputs.input_ids
past_key_values = DynamicCache()
ep = torch.export.export(
model,
(),
{
"input_ids": input_ids,
"attention_mask": attention_mask,
"past_key_values": past_key_values,
"use_cache": True,
},
strict=False,
)
res = ep.module()(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=True,
)
self.assertTrue(len(res.past_key_values.key_cache) == model.config.num_hidden_layers)
self.assertEqual(2 * model.config.num_hidden_layers + 1, len(ep.graph_signature.output_specs))
self.assertEqual(
3,
len(
[
x
for x in ep.graph_signature.input_specs
if x.kind == torch.export.graph_signature.InputKind.USER_INPUT
]
),
)
past_key_values_eager = DynamicCache()
res_eager = model(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values_eager,
use_cache=True,
)
self.assertTrue(torch.allclose(res.logits, res_eager.logits))
for k1, k2 in zip(res.past_key_values.key_cache, res_eager.past_key_values.key_cache):
self.assertTrue(torch.allclose(k1, k2))
for v1, v2 in zip(res.past_key_values.value_cache, res_eager.past_key_values.value_cache):
self.assertTrue(torch.allclose(v1, v2))

    def test_static_cache_exportability(self):
        """
        Tests that static cache works with `torch.export()`
        """
        if not is_torch_greater_or_equal("2.3"):
            self.skipTest(reason="This test requires torch >= 2.3 to run.")

        set_seed(0)
        device = "cpu"
        dtype = "bfloat16"
        cache_implementation = "static"
        attn_implementation = "sdpa"  # Export and ExecuTorch only work with SDPA attention
        batch_size = 1
        max_cache_len = 1234
        model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM"
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map=device,
            torch_dtype=dtype,
            attn_implementation=attn_implementation,
            generation_config=GenerationConfig(
                use_cache=True,
                cache_implementation=cache_implementation,
                max_length=max_cache_len,
                cache_config={
                    "batch_size": batch_size,
                    "max_cache_len": max_cache_len,
                    "device": device,
                },
            ),
        )

        # Check if cache config is passed through correctly
        self.assertEqual(model.generation_config.use_cache, True)
        self.assertEqual(model.generation_config.cache_implementation, cache_implementation)
        self.assertEqual(model.generation_config.max_length, max_cache_len)
        self.assertTrue(model.generation_config.cache_config is not None)
        self.assertEqual(model.generation_config.cache_config.batch_size, batch_size)
        self.assertEqual(model.generation_config.cache_config.max_cache_len, max_cache_len)

        exported_program = convert_and_export_with_cache(model)

        # Check if the exported model is configured with the `StaticCache` correctly
        n_static_key_caches = n_static_value_caches = 0
        for buffer_name, buffer in exported_program.named_buffers():
            if buffer_name.startswith("key_cache"):
                self.assertTrue(buffer.shape[0] == batch_size)
                self.assertTrue(buffer.shape[2] == max_cache_len)
                n_static_key_caches = n_static_key_caches + 1
            if buffer_name.startswith("value_cache"):
                self.assertTrue(buffer.shape[0] == batch_size)
                self.assertTrue(buffer.shape[2] == max_cache_len)
                n_static_value_caches = n_static_value_caches + 1
        self.assertEqual(n_static_key_caches, model.config.num_hidden_layers)
        self.assertEqual(n_static_value_caches, model.config.num_hidden_layers)

        # Export with dynamic shapes using Dim.AUTO
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        input_ids = tokenizer("Here's everything I know", return_tensors="pt").input_ids
        dynamic_shapes = {"input_ids": {1: torch.export.Dim.AUTO}, "cache_position": None}
        exported_program = convert_and_export_with_cache(
            model,
            example_input_ids=input_ids,
            dynamic_shapes=dynamic_shapes,
            strict=False,
        )

    def test_hybrid_cache_exportability(self):
        """
        Tests that hybrid cache works with `torch.export()`
        """
        if not is_torch_greater_or_equal("2.6"):
            self.skipTest(reason="This test requires torch >= 2.6 to run.")

        from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

        set_seed(0)
        model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
        model = AutoModelForCausalLM.from_pretrained(model_id)
        model.eval()
        self.assertEqual(model.config.use_cache, True)
        self.assertEqual(model.config.cache_implementation, "hybrid")

        # Export + HybridCache
        model.eval()
        max_batch_size = 1
        max_cache_len = 23
        exportable_module = TorchExportableModuleForDecoderOnlyLM(model, max_batch_size, max_cache_len)
        exported_program = exportable_module.export()
        n_g_key_caches = n_g_value_caches = 0
        for buffer_name, buffer in exported_program.named_buffers():
            if buffer_name.startswith("key_cache"):
                self.assertTrue(buffer.shape[0] == max_batch_size)
                self.assertTrue(buffer.shape[2] == max_cache_len)
                n_g_key_caches = n_g_key_caches + 1
            if buffer_name.startswith("value_cache"):
                self.assertTrue(buffer.shape[0] == max_batch_size)
                self.assertTrue(buffer.shape[2] == max_cache_len)
                n_g_value_caches = n_g_value_caches + 1
        self.assertEqual(n_g_key_caches, model.config.num_hidden_layers)
        self.assertEqual(n_g_value_caches, model.config.num_hidden_layers)

        # Export with dynamic shapes using Dim.AUTO
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        input_ids = tokenizer("Here's everything I know", return_tensors="pt").input_ids
        dynamic_shapes = {"input_ids": {1: torch.export.Dim.AUTO}, "cache_position": None}
        exported_program = exportable_module.export(
            input_ids=input_ids,
            dynamic_shapes=dynamic_shapes,
            strict=False,
        )