# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import unittest

from packaging import version
from parameterized import parameterized

from transformers import set_seed
from transformers.generation.configuration_utils import ALL_CACHE_IMPLEMENTATIONS
from transformers.testing_utils import (
    CaptureStderr,
    backend_device_count,
    backend_torch_accelerator_module,
    cleanup,
    get_gpu_count,
    is_torch_available,
    require_read_token,
    require_torch,
    require_torch_accelerator,
    require_torch_gpu,
    require_torch_multi_accelerator,
    require_torch_multi_gpu,
    slow,
    torch_device,
)
from transformers.utils import is_optimum_quanto_available, is_torch_greater_or_equal


if is_torch_available():
    import torch

    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        Cache,
        ClvpForCausalLM,
        DynamicCache,
        Gemma2Config,
        GenerationConfig,
        HybridCache,
        LlamaConfig,
        SlidingWindowCache,
        StaticCache,
        convert_and_export_with_cache,
        pipeline,
    )
    from transformers.integrations.executorch import export_with_dynamic_cache


TEST_CACHE_IMPLEMENTATIONS = [
    cache_name
    for cache_name in ALL_CACHE_IMPLEMENTATIONS
    # TODO (joao): Mamba is not compatible with most models, remove from `ALL_CACHE_IMPLEMENTATIONS`?
    if cache_name != "mamba"
    # TODO (joao): offloaded_hybrid == offloaded_hybrid_chunked, deprecate one of them
    if cache_name != "offloaded_hybrid"
]
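
# The integration tests below are parameterized over this list; each one calls
# `_skip_on_failed_cache_prerequisites` first, so cache implementations whose requirements are not met
# (e.g. quanto not installed, no accelerator available) are skipped rather than reported as failures.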


@require_torch
class CacheTest(unittest.TestCase):
    """Cache tests that don't require loading models"""

    def test_dynamic_cache_retrocompatibility(self):
        """Tests that we can convert back and forth between the legacy cache format and DynamicCache"""
        legacy_cache = ()
        new_cache = DynamicCache()

        # Creates a new cache with 10 layers in both formats
        for layer_idx in range(10):
            new_key = torch.rand((2, 4, 8, 16))
            new_value = torch.rand((2, 4, 8, 16))
            new_cache.update(new_key, new_value, layer_idx)
            legacy_cache += ((new_key, new_value),)

        # Sanity check 1: they must have the same shapes
        self.assertEqual(len(legacy_cache), len(new_cache))
        for layer_idx in range(10):
            self.assertEqual(len(legacy_cache[layer_idx]), len(new_cache[layer_idx]))
            for key_value_idx in range(2):
                self.assertTrue(
                    legacy_cache[layer_idx][key_value_idx].shape == new_cache[layer_idx][key_value_idx].shape
                )

        # Sanity check 2: we can get the sequence length in multiple ways with DynamicCache, and they return the
        # expected value
        self.assertTrue(legacy_cache[0][0].shape[-2] == new_cache[0][0].shape[-2] == new_cache.get_seq_length() == 8)

        # Sanity check 3: they must be equal, and both support indexing
        for layer_idx in range(10):
            for key_value_idx in range(2):
                self.assertTrue(
                    torch.allclose(new_cache[layer_idx][key_value_idx], legacy_cache[layer_idx][key_value_idx])
                )

        # Test 1: We can convert from legacy to new with no changes
        from_legacy = DynamicCache.from_legacy_cache(legacy_cache)
        for layer_idx in range(10):
            for key_value_idx in range(2):
                self.assertTrue(
                    torch.allclose(from_legacy[layer_idx][key_value_idx], legacy_cache[layer_idx][key_value_idx])
                )

        # Test 2: We can convert from new to legacy with no changes
        to_legacy = new_cache.to_legacy_cache()
        for layer_idx in range(10):
            for key_value_idx in range(2):
                self.assertTrue(
                    torch.allclose(to_legacy[layer_idx][key_value_idx], new_cache[layer_idx][key_value_idx])
                )
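
        # A minimal sketch (not asserted above) of the same round trip in user code:
        #     legacy = cache.to_legacy_cache()                # tuple of per-layer (key, value) tensor pairs
        #     cache = DynamicCache.from_legacy_cache(legacy)  # rebuild a Cache object from those pairs
        # with every tensor laid out as (batch_size, num_heads, seq_len, head_dim), as in the loop above.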

    def test_reorder_cache_retrocompatibility(self):
        """Tests that Cache.reorder_cache is retrocompatible with the legacy code path"""
        legacy_reorder_fn = ClvpForCausalLM._reorder_cache  # An example of a legacy `_reorder_cache` function

        legacy_cache = ()
        new_cache = DynamicCache()

        # Creates a new cache with 10 layers in both formats
        for layer_idx in range(10):
            new_key = torch.rand((4, 4, 8, 16))
            new_value = torch.rand((4, 4, 8, 16))
            new_cache.update(new_key, new_value, layer_idx)
            legacy_cache += ((new_key, new_value),)

        # Let's create some dummy beam indices. From the shape above, it is equivalent to the case where num_beams=4
        # and batch_size=1
        beam_idx = torch.randint(low=0, high=4, size=(4,))

        legacy_cache_reordered = legacy_reorder_fn(legacy_cache, beam_idx)
        new_cache.reorder_cache(beam_idx)

        # Let's check that the results are the same
        for layer_idx in range(10):
            for key_value_idx in range(2):
                self.assertTrue(
                    torch.allclose(
                        new_cache[layer_idx][key_value_idx], legacy_cache_reordered[layer_idx][key_value_idx]
                    )
                )

    def test_static_cache_mha_mqa_gqa(self):
        """
        Tests that static cache works with multi-head attention (MHA), grouped query attention (GQA), and multi-query
        attention (MQA)
        """

        def _random_kvs(config):
            # shape for keys and values: (batch_size, num_heads, seq_len, head_dim)
            random_keys = torch.rand(
                (1, config.num_key_value_heads, 1, config.hidden_size // config.num_attention_heads),
                device=torch_device,
            )
            random_values = torch.rand(
                (1, config.num_key_value_heads, 1, config.hidden_size // config.num_attention_heads),
                device=torch_device,
            )
            return random_keys, random_values

        mha_config = LlamaConfig(num_attention_heads=32)
        mha_static_cache = StaticCache(config=mha_config, max_batch_size=1, max_cache_len=10, device=torch_device)
        cached_keys, cached_values = mha_static_cache.update(
            *_random_kvs(mha_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)}
        )
        self.assertTrue(cached_keys.shape == (1, 32, 10, 128))
        self.assertTrue(cached_values.shape == (1, 32, 10, 128))

        gqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=4)
        gqa_static_cache = StaticCache(config=gqa_config, max_batch_size=1, max_cache_len=10, device=torch_device)
        cached_keys, cached_values = gqa_static_cache.update(
            *_random_kvs(gqa_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)}
        )
        self.assertTrue(cached_keys.shape == (1, 4, 10, 128))
        self.assertTrue(cached_values.shape == (1, 4, 10, 128))

        mqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=1)
        mqa_static_cache = StaticCache(config=mqa_config, max_batch_size=1, max_cache_len=10, device=torch_device)
        cached_keys, cached_values = mqa_static_cache.update(
            *_random_kvs(mqa_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)}
        )
        self.assertTrue(cached_keys.shape == (1, 1, 10, 128))
        self.assertTrue(cached_values.shape == (1, 1, 10, 128))
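
        # Note on the expected shapes above: StaticCache pre-allocates one tensor of shape
        # (max_batch_size, num_key_value_heads, max_cache_len, head_dim) per layer. With the default
        # LlamaConfig, head_dim = hidden_size // num_attention_heads = 4096 // 32 = 128, so the MHA, GQA
        # and MQA variants only differ in the key/value head dimension (32, 4 and 1 respectively).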


def _skip_on_failed_cache_prerequisites(test, cache_implementation):
    """Function to skip tests on failed cache prerequisites, given a cache implementation"""
    # Installed dependencies
    if cache_implementation == "quantized" and not is_optimum_quanto_available():
        test.skipTest("Quanto is not available")
    # Devices
    if "offloaded" in cache_implementation:
        has_accelerator = torch_device is not None and torch_device != "cpu"
        if not has_accelerator:
            test.skipTest("Offloaded caches require an accelerator")
        if cache_implementation in ["offloaded_static", "offloaded_hybrid_chunked"]:
            if backend_device_count(torch_device) != 1:
                test.skipTest("Offloaded static caches require exactly 1 accelerator")


class CacheIntegrationTest(unittest.TestCase):
    """Fast cache integration tests that share the same small model"""

    @classmethod
    def setUpClass(cls):
        # Load once and reuse across tests
        cls.tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M-Instruct", padding_side="left")
        cls.model = AutoModelForCausalLM.from_pretrained(
            "HuggingFaceTB/SmolLM2-135M-Instruct", device_map="auto", torch_dtype=torch.float16
        )
        cls.model.config.sliding_window = 256  # hack to enable the use of caches with sliding windows
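
        # Note: the sliding-window cache implementations read `config.sliding_window` when the cache is
        # built, so overriding it here is what lets the parameterized tests below exercise those cache
        # classes with this small model as well.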

    @parameterized.expand(TEST_CACHE_IMPLEMENTATIONS)
    def test_cache_batched(self, cache_implementation):
        """Sanity check: caches' `.update` function expects batched inputs"""
        _skip_on_failed_cache_prerequisites(self, cache_implementation)

        EXPECTED_GENERATION = ["A sequence: 1, 2, 3, 4, 5, 6, 7, 8,", "A sequence: A, B, C, D, E, F, G, H"]

        inputs = self.tokenizer(
            ["A sequence: 1, 2, 3, 4, 5", "A sequence: A, B, C"], padding=True, return_tensors="pt"
        )
        inputs = inputs.to(self.model.device)

        gen_out = self.model.generate(
            **inputs,
            do_sample=False,
            max_new_tokens=10,
            return_dict_in_generate=True,
            cache_implementation=cache_implementation,
            disable_compile=True,
        )
        # Sanity check: a cache was used
        self.assertIsInstance(gen_out.past_key_values, Cache)
        # Confirm that the output matches expectations
        decoded = self.tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True)
        self.assertListEqual(decoded, EXPECTED_GENERATION)

    @parameterized.expand(TEST_CACHE_IMPLEMENTATIONS)
    def test_cache_beam_search(self, cache_implementation):
        """
        Sanity check: caches' `reorder_cache` is operational. We can confirm this by looking at the beam indices
        (an output sequence contains multiple beam indices).
        """
        _skip_on_failed_cache_prerequisites(self, cache_implementation)
        if cache_implementation == "offloaded_hybrid_chunked":
            # TODO (joao, cyril): something is off with `offloaded_hybrid_chunked` aka `OffloadedHybridCache`: the
            # output sequence (and the corresponding beam scores, if we add `output_scores=True`) are significantly
            # different from the other caches.
            self.skipTest("`offloaded_hybrid_chunked` fails this test")

        EXPECTED_GENERATION = [
            "Blue is the color of the sky, and the color of",
            "Blue is the color of the sky, and the second is",
        ]

        inputs = self.tokenizer(["Blue is"], return_tensors="pt").to(self.model.device)
        gen_out = self.model.generate(
            **inputs,
            do_sample=False,
            max_new_tokens=10,
            num_beams=2,
            num_return_sequences=2,
            cache_implementation=cache_implementation,
            disable_compile=True,
            return_dict_in_generate=True,
        )
        # Sanity check: a cache was used
        self.assertIsInstance(gen_out.past_key_values, Cache)
        # At least one of the sequences requires multiple beam indices -> `reorder_cache` had to shift things around
        self.assertTrue(any(len(set(beams_in_sequence)) > 1 for beams_in_sequence in gen_out.beam_indices))
        # Confirm that the output matches expectations
        decoded = self.tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True)
        self.assertListEqual(decoded, EXPECTED_GENERATION)

    @parameterized.expand(TEST_CACHE_IMPLEMENTATIONS)
    def test_cache_extra_left_padding(self, cache_implementation):
        """Tests that adding extra left-padding does not affect the generation with the cache"""
        _skip_on_failed_cache_prerequisites(self, cache_implementation)

        EXPECTED_GENERATION = ["The cat's whiskers are also a sign of anxiety."]

        inputs = self.tokenizer(["The cat"], padding=True, return_tensors="pt").to(self.model.device)
        generation_kwargs = {
            "do_sample": False,
            "max_new_tokens": 10,
            "cache_implementation": cache_implementation,
            "disable_compile": True,
        }

        gen_out = self.model.generate(**inputs, **generation_kwargs)
        decoded = self.tokenizer.batch_decode(gen_out, skip_special_tokens=True)
        self.assertListEqual(decoded, EXPECTED_GENERATION)

        # Now with extra left-padding
        inputs_expanded = self.tokenizer(["The cat"], padding=True, return_tensors="pt", pad_to_multiple_of=32)
        inputs_expanded = inputs_expanded.to(self.model.device)
        self.assertTrue(inputs.input_ids.shape[1] < inputs_expanded.input_ids.shape[1])
        gen_out = self.model.generate(**inputs_expanded, **generation_kwargs)
        decoded = self.tokenizer.batch_decode(gen_out, skip_special_tokens=True)
        self.assertListEqual(decoded, EXPECTED_GENERATION)


@require_torch_accelerator
class CacheHardIntegrationTest(unittest.TestCase):
    """Hard cache integration tests that require loading different models"""

    def setUp(self):
        # Clears memory before each test. Some tests use large models, which might result in suboptimal torch
        # re-allocation if we run multiple tests in a row without clearing memory.
        cleanup(torch_device, gc_collect=True)

    @classmethod
    def tearDownClass(cls):
        # Clears memory after the last test. See `setUp` for more details.
        cleanup(torch_device, gc_collect=True)

    @slow
    def test_dynamic_cache_hard(self):
        """Hard test for base cache implementation -- minor numerical fluctuations will cause this test to fail"""
        tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B", padding_side="left")
        model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-4B", device_map="auto", torch_dtype=torch.bfloat16)
        inputs = tokenizer(["Here's everything I know about cats. Cats"], return_tensors="pt").to(model.device)

        set_seed(0)
        gen_out = model.generate(
            **inputs, do_sample=True, max_new_tokens=256, return_dict_in_generate=True, output_scores=True
        )
        decoded = tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True)
        # sum of the scores for the generated tokens
        input_length = inputs.input_ids.shape[1]
        score_sum = sum(
            [score[0][gen_out.sequences[0][input_length + idx]] for idx, score in enumerate(gen_out.scores)]
        )

        EXPECTED_GENERATION = (
            "Here's everything I know about cats. Cats are mammals, they have four legs, they have a tail, they have "
            "a face with a nose, eyes, and mouth. They have fur, they have claws, and they have a body that is "
            "covered in fur. They are carnivores, so they eat meat. They are also very clean animals, they groom "
            "themselves. They have a lot of different breeds. Some are small, some are large. Some are friendly, "
            "some are not. They have a lot of different personalities. They can be very independent, or they can be "
            "very affectionate. They can be very playful, or they can be very lazy. They can be very intelligent, or "
            "they can be very silly. They have a lot of different behaviors. They can be very curious, or they can "
            "be very cautious. They can be very vocal, or they can be very quiet. They can be very social, or they "
            "can be very solitary. They can be very active, or they can be very inactive. They can be very "
            "affectionate, or they can be very aloof. They can be very playful, or they can be very lazy. They can "
            "be very intelligent, or they can be very silly. They have a lot of different behaviors. They can be "
            "very curious, or they can"
        )
        EXPECTED_SCORE_SUM = 11017.4971
        self.assertEqual(decoded[0], EXPECTED_GENERATION)
        self.assertAlmostEqual(score_sum, EXPECTED_SCORE_SUM, places=2)
        self.assertIsInstance(gen_out.past_key_values, DynamicCache)  # sanity check
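
        # `score_sum` adds up, for each generated position, the score that `generate` assigned to the token
        # actually sampled there -- a compact fingerprint of the whole run, which is what makes this test
        # sensitive to small numerical changes in the cache path.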

    @parameterized.expand([("eager"), ("sdpa")])
    @require_torch_accelerator
    @slow
    def test_static_cache_greedy_decoding_pad_left(self, attn_implementation):
        """Tests that different cache implementations work well with eager and SDPA inference"""
        EXPECTED_GENERATION = [
            "The best color is the one that is most suitable for the purpose.",
            "We should not undermind the issues at hand, but instead, we should focus on the things",
        ]

        tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B", padding_side="left")
        model = AutoModelForCausalLM.from_pretrained(
            "Qwen/Qwen3-4B",
            torch_dtype=torch.bfloat16,
            attn_implementation=attn_implementation,
            device_map="auto",
        )
        inputs = tokenizer(
            ["The best color is", "We should not undermind the issues at hand"], padding=True, return_tensors="pt"
        ).to(model.device)
        generation_kwargs = {"do_sample": False, "max_new_tokens": 10, "return_dict_in_generate": True}

        set_seed(0)
        gen_out = model.generate(**inputs, **generation_kwargs)
        decoded = tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True)
        with self.subTest(f"{attn_implementation}, dynamic"):
            self.assertListEqual(decoded, EXPECTED_GENERATION)
            self.assertIsInstance(gen_out.past_key_values, DynamicCache)  # sanity check

        set_seed(0)
        gen_out = model.generate(**inputs, **generation_kwargs, cache_implementation="static", disable_compile=True)
        decoded = tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True)
        with self.subTest(f"{attn_implementation}, static, eager"):
            self.assertListEqual(decoded, EXPECTED_GENERATION)
            self.assertIsInstance(gen_out.past_key_values, StaticCache)  # sanity check

        set_seed(0)
        gen_out = model.generate(**inputs, **generation_kwargs, cache_implementation="static")
        decoded = tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True)
        with self.subTest(f"{attn_implementation}, static, compiled"):
            self.assertListEqual(decoded, EXPECTED_GENERATION)
            self.assertIsInstance(gen_out.past_key_values, StaticCache)  # sanity check
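
        # The last call deliberately omits `disable_compile=True`: with `cache_implementation="static"`,
        # `generate` compiles the decoding step by default, so this subtest also covers the compiled path
        # (hence the "static, compiled" label above).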

    @require_torch_accelerator
    @slow
    def test_offloaded_cache_uses_less_memory_than_dynamic_cache(self):
        """Tests that OffloadedCache uses less memory than the default DynamicCache"""
        model_name = "microsoft/Phi-3-mini-4k-instruct"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
        device = model.device

        if not is_torch_greater_or_equal("2.7", accept_dev=True) and device.type == "xpu":
            self.skipTest(reason="This test requires torch >= 2.7 to run on xpu.")

        input_text = "Fun fact:"
        inputs = tokenizer(input_text, return_tensors="pt").to(device)
        common = {
            "num_beams": 4,
            "num_beam_groups": 2,
            "num_return_sequences": 4,
            "diversity_penalty": 1.0,
            "max_new_tokens": 20,
            "early_stopping": True,
        }
        original = GenerationConfig(**common)
        offloaded = GenerationConfig(cache_implementation="offloaded", **common)

        torch_accelerator_module = backend_torch_accelerator_module(device.type)

        torch_accelerator_module.reset_peak_memory_stats(device)
        model.generate(generation_config=original, **inputs)
        original_peak_memory = torch_accelerator_module.max_memory_allocated(device)
        torch_accelerator_module.reset_peak_memory_stats(device)
        model.generate(generation_config=offloaded, **inputs)
        offloaded_peak_memory = torch_accelerator_module.max_memory_allocated(device)
        self.assertTrue(offloaded_peak_memory < original_peak_memory)
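
        # Equivalent shorthand for the offloaded run (a sketch of the same knob used above):
        #     model.generate(**inputs, cache_implementation="offloaded", max_new_tokens=20)
        # The offloaded cache keeps past key/values off the accelerator and fetches them per layer, which
        # is what the peak-memory comparison above is designed to detect.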

    @require_torch_accelerator
    @slow
    def test_cache_copy(self):
        """Tests that we can manually set a cache, copy, and reuse it for generation"""
        # TODO (joao): test for all cache implementations in `CacheIntegrationTest` after standardizing the
        # lazy init of cache layers
        model_name = "microsoft/Phi-3-mini-4k-instruct"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name, device_map=torch_device, torch_dtype=torch.bfloat16)

        prompt_cache = StaticCache(
            config=model.config, max_batch_size=1, max_cache_len=1024, device=torch_device, dtype=torch.bfloat16
        )

        INITIAL_PROMPT = "You are a helpful assistant. "
        inputs_initial_prompt = tokenizer(INITIAL_PROMPT, return_tensors="pt").to(torch_device)
        # This is the common prompt cached, we need to run forward without grad to be able to copy
        with torch.no_grad():
            prompt_cache = model(**inputs_initial_prompt, past_key_values=prompt_cache).past_key_values

        prompts = ["Help me to write a blogpost about travelling.", "What is the capital of France?"]
        responses = []
        for prompt in prompts:
            new_inputs = tokenizer(INITIAL_PROMPT + prompt, return_tensors="pt").to(torch_device)
            past_key_values = copy.deepcopy(prompt_cache)
            outputs = model.generate(
                **new_inputs, past_key_values=past_key_values, max_new_tokens=40, disable_compile=True
            )
            response = tokenizer.batch_decode(outputs)[0]
            responses.append(response)

        EXPECTED_DECODED_TEXT = [
            "You are a helpful assistant. Help me to write a blogpost about travelling.\n\nTraveling is an "
            "enriching experience that broadens our horizons and allows us to explore the world beyond our comfort "
            "zones. Whether it's a short weekend getaway",
            "You are a helpful assistant. What is the capital of France?\n\n\n## Response:Paris is the capital "
            "of France.\n\n\n\n\n\n\n<|endoftext|>",
        ]

        self.assertEqual(responses, EXPECTED_DECODED_TEXT)
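
        # Pattern being exercised (sketch): prefill the shared system prompt once into a StaticCache, then
        # `copy.deepcopy` it per continuation. The copy is needed because `Cache.update()` writes into the
        # cache tensors in place, so reusing the same object across prompts would mix their key/value states.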

    @require_torch_multi_gpu
    def test_data_parallel_dynamic_cache(self):
        """
        Tests that the dynamic cache works with nn.DataParallel. Under the hood, `DynamicCache` is rebuilt from
        multiple `DynamicCache` in the gather step.
        """

        model_repo = "hf-internal-testing/tiny-random-MistralForCausalLM"
        model = AutoModelForCausalLM.from_pretrained(model_repo).to(torch_device)
        tokenizer = AutoTokenizer.from_pretrained(model_repo)

        # w/o DP: batch_size = num_gpus
        # w/ DP: batch_size = 1 (with num_gpus replicas)
        num_gpus = get_gpu_count()
        model_inputs = tokenizer(["foo bar"] * num_gpus, return_tensors="pt").to(model.device)

        # w/o DP
        no_parallelism_cache = model(**model_inputs).past_key_values
        self.assertIsInstance(no_parallelism_cache, DynamicCache)

        # w/ DP
        model = torch.nn.DataParallel(model)
        parallelism_cache = model(**model_inputs).past_key_values
        self.assertIsInstance(parallelism_cache, DynamicCache)

        # Check that the caches are the same
        for layer_idx in range(len(no_parallelism_cache)):
            for kv_idx in range(2):  # 0 = key, 1 = value
                torch.testing.assert_close(
                    actual=parallelism_cache[layer_idx][kv_idx], expected=no_parallelism_cache[layer_idx][kv_idx]
                )

    @require_torch_gpu
    def test_static_cache_no_cuda_graph_skips(self):
        """
        Tests that generating with a static cache and compilation doesn't skip cuda graphs. Regression test for
        #36543.

        (Note: we set `fullgraph=True`, which according to the torch docs should raise an exception. In practice,
        these messages are emitted to stderr instead, hence the stderr check below.)
        """
        model_repo = "hf-internal-testing/tiny-random-MistralForCausalLM"
        model = AutoModelForCausalLM.from_pretrained(model_repo).to(torch_device)
        tokenizer = AutoTokenizer.from_pretrained(model_repo)
        inputs = tokenizer(["foo bar"], return_tensors="pt").to(torch_device)

        # on `main`, prior to #36543, this would send stderr messages about cuda graphs being skipped.
        with CaptureStderr() as cap:
            model.generate(**inputs, max_new_tokens=2, cache_implementation="static")
        self.assertNotIn("cuda", cap.err.lower())

    @require_torch_multi_accelerator
    @slow
    @require_read_token
    def test_static_cache_multi_accelerator(self):
        """Regression test for #35164: static cache with multi-accelerator"""

        model_id = "google/gemma-2-2b-it"
        tokenizer = AutoTokenizer.from_pretrained(model_id)

        device_map = {"model.embed_tokens": 0, "model.norm": 1, "model.rotary_emb": 1, "lm_head": 0}
        num_hidden_layers = 26
        for i in range(num_hidden_layers):
            device_map[f"model.layers.{i}"] = 0 if i < 13 else 1

        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype="bfloat16",
            device_map=device_map,
        )
        inputs = tokenizer("Today is a beautiful day!", return_tensors="pt").to(0)
        _ = model(**inputs)
        _ = model.generate(**inputs, max_new_tokens=2, cache_implementation="hybrid")
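
        # The manual device_map above pins the first 13 decoder layers to accelerator 0 and the rest to
        # accelerator 1, so the per-layer caches end up on two different devices -- presumably the
        # multi-device layout involved in #35164. The test only checks that a forward pass and hybrid-cache
        # generation run without device-placement errors.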

    @require_torch_gpu
    @parameterized.expand(TEST_CACHE_IMPLEMENTATIONS)
    def test_cache_gptj_model(self, cache_implementation):
        """Tests caches with GPT-J model. Regression test for https://github.com/huggingface/transformers/pull/34799"""
        _skip_on_failed_cache_prerequisites(self, cache_implementation)

        model_id = "hf-internal-testing/tiny-random-GPTJForCausalLM"
        pipe = pipeline("text-generation", model=model_id, torch_dtype=torch.bfloat16)
        pipe.model.config.sliding_window = (
            256 if cache_implementation in ["sliding_window", "hybrid", "hybrid_chunked"] else None
        )
        out = pipe(
            "hello world",
            cache_implementation=cache_implementation,
            max_new_tokens=10,
            do_sample=False,
            disable_compile=True,
            return_tensors=True,
        )[0]["generated_token_ids"][-10:]
        EXPECTED_OUTPUT = [879, 175, 39, 141, 1000, 975, 951, 991, 683, 441]
        self.assertListEqual(out, EXPECTED_OUTPUT)


@require_torch
class CacheExportIntegrationTest(unittest.TestCase):
    """Cache tests that rely on `torch.export()` and model loading"""

    def test_dynamic_cache_exportability(self):
        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
        model = model.eval()
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
        prompt = "What is the best way to debug python script?"
        inputs = tokenizer(prompt, return_tensors="pt")
        attention_mask = inputs.attention_mask
        input_ids = inputs.input_ids

        ep = export_with_dynamic_cache(model, input_ids, attention_mask)
        res = ep.module()(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=DynamicCache(),
            use_cache=True,
        )
        self.assertTrue(len(res.past_key_values.key_cache) == model.config.num_hidden_layers)
        self.assertEqual(2 * model.config.num_hidden_layers + 1, len(ep.graph_signature.output_specs))
        self.assertEqual(
            3,
            len(
                [
                    x
                    for x in ep.graph_signature.input_specs
                    if x.kind == torch.export.graph_signature.InputKind.USER_INPUT
                ]
            ),
        )

        past_key_values_eager = DynamicCache()
        res_eager = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values_eager,
            use_cache=True,
        )
        self.assertTrue(torch.allclose(res.logits, res_eager.logits))
        for k1, k2 in zip(res.past_key_values.key_cache, res_eager.past_key_values.key_cache):
            self.assertTrue(torch.allclose(k1, k2))

        for v1, v2 in zip(res.past_key_values.value_cache, res_eager.past_key_values.value_cache):
            self.assertTrue(torch.allclose(v1, v2))
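
        # How the signature checks above decode: 2 * num_hidden_layers + 1 outputs correspond to a key and
        # a value cache tensor per layer plus the logits; the 3 user inputs are assumed to be `input_ids`,
        # `attention_mask`, and the flattened cache argument (an assumption, not asserted by the test).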

    def test_static_cache_exportability(self):
        """
        Tests that static cache works with `torch.export()`
        """
        if not is_torch_greater_or_equal("2.3"):
            self.skipTest(reason="This test requires torch >= 2.3 to run.")

        set_seed(0)
        device = "cpu"
        dtype = "bfloat16"
        cache_implementation = "static"
        attn_implementation = "sdpa"  # Export and ExecuTorch only works for SdpaAttention
        batch_size = 1
        max_cache_len = 1234
        model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM"
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map=device,
            torch_dtype=dtype,
            attn_implementation=attn_implementation,
            generation_config=GenerationConfig(
                use_cache=True,
                cache_implementation=cache_implementation,
                max_length=max_cache_len,
                cache_config={
                    "batch_size": batch_size,
                    "max_cache_len": max_cache_len,
                    "device": device,
                },
            ),
        )
        # Check if cache config is passed through correctly
        self.assertEqual(model.generation_config.use_cache, True)
        self.assertEqual(model.generation_config.cache_implementation, cache_implementation)
        self.assertEqual(model.generation_config.max_length, max_cache_len)
        self.assertTrue(model.generation_config.cache_config is not None)
        self.assertEqual(model.generation_config.cache_config.batch_size, batch_size)
        self.assertEqual(model.generation_config.cache_config.max_cache_len, max_cache_len)

        exported_program = convert_and_export_with_cache(model)

        # Check if the exported model is configured with the `StaticCache` correctly
        n_static_key_caches = n_static_value_caches = 0
        for buffer_name, buffer in exported_program.named_buffers():
            if buffer_name.startswith("key_cache"):
                self.assertTrue(buffer.shape[0] == batch_size)
                self.assertTrue(buffer.shape[2] == max_cache_len)
                n_static_key_caches = n_static_key_caches + 1
            if buffer_name.startswith("value_cache"):
                self.assertTrue(buffer.shape[0] == batch_size)
                self.assertTrue(buffer.shape[2] == max_cache_len)
                n_static_value_caches = n_static_value_caches + 1
        self.assertEqual(n_static_key_caches, model.config.num_hidden_layers)
        self.assertEqual(n_static_value_caches, model.config.num_hidden_layers)

        # Export with dynamic shapes
        input_ids = torch.zeros((1, 3), dtype=torch.long)
        cache_position = torch.tensor([0, 1, 2], dtype=torch.long)
        dynamic_shapes = {"input_ids": {1: torch.export.Dim.DYNAMIC}, "cache_position": {0: torch.export.Dim.DYNAMIC}}
        strict = version.parse(torch.__version__) != version.parse("2.7.0")
        exported_program = convert_and_export_with_cache(
            model,
            example_input_ids=input_ids,
            example_cache_position=cache_position,
            dynamic_shapes=dynamic_shapes,
            strict=strict,
        )

        from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
        exported_program = exportable_module.export(
            input_ids=input_ids,
            cache_position=cache_position,
            dynamic_shapes=dynamic_shapes,
            strict=strict,
        )

    def test_hybrid_cache_exportability(self):
        """
        Tests that hybrid cache works with `torch.export()`
        """
        if not is_torch_greater_or_equal("2.6"):
            self.skipTest(reason="This test requires torch >= 2.6 to run.")

        from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

        set_seed(0)
        model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
        model = AutoModelForCausalLM.from_pretrained(model_id)
        model.eval()
        self.assertEqual(model.config.use_cache, True)
        self.assertEqual(model.config.cache_implementation, "hybrid")

        # Export + HybridCache
        model.eval()
        max_batch_size = 1
        max_cache_len = 23
        exportable_module = TorchExportableModuleForDecoderOnlyLM(model, max_batch_size, max_cache_len)
        exported_program = exportable_module.export()
        n_g_key_caches = n_g_value_caches = 0
        for buffer_name, buffer in exported_program.named_buffers():
            if buffer_name.startswith("key_cache"):
                self.assertTrue(buffer.shape[0] == max_batch_size)
                self.assertTrue(buffer.shape[2] == max_cache_len)
                n_g_key_caches = n_g_key_caches + 1
            if buffer_name.startswith("value_cache"):
                self.assertTrue(buffer.shape[0] == max_batch_size)
                self.assertTrue(buffer.shape[2] == max_cache_len)
                n_g_value_caches = n_g_value_caches + 1
        self.assertEqual(n_g_key_caches, model.config.num_hidden_layers)
        self.assertEqual(n_g_value_caches, model.config.num_hidden_layers)

        # Export with dynamic shapes
        input_ids = torch.zeros((1, 3), dtype=torch.long)
        cache_position = torch.tensor([0, 1, 2], dtype=torch.long)
        dynamic_shapes = {"input_ids": {1: torch.export.Dim.DYNAMIC}, "cache_position": {0: torch.export.Dim.DYNAMIC}}
        strict = version.parse(torch.__version__) != version.parse("2.7.0")
        exported_program = exportable_module.export(
            input_ids=input_ids,
            cache_position=cache_position,
            dynamic_shapes=dynamic_shapes,
            strict=strict,
        )


class SyntheticCacheTest(unittest.TestCase):
    """Tests cache behavior with simple dummy data."""

    def setUp(self):
        """Set up common configuration and cache instances for all tests."""
        self.window_size = 4
        self.max_cache_len = 4
        self.config = Gemma2Config(
            num_hidden_layers=1,
            num_key_value_heads=1,
            num_attention_heads=1,
            head_dim=1,
            hidden_size=1,
            sliding_window=self.window_size,
            sliding_window_pattern=2,  # Default pattern for hybrid sliding
        )
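
        # With one layer, one head, and head_dim == 1, every cache tensor is effectively a vector of length
        # `max_cache_len`, so the tests below can assert the full cache contents as plain Python lists via
        # `key_cache[0][0, 0, :, 0].tolist()`.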

    def test_static_cache_out_of_bounds(self):
        """Test StaticCache raises IndexError for out-of-bounds positions."""
        static_cache = StaticCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
        pos_out_of_bounds = torch.tensor([self.max_cache_len])  # Position >= max_cache_len

        with self.assertRaises(IndexError):
            static_cache.update(
                key_states=torch.tensor([[[[1.0]]]]),
                value_states=torch.tensor([[[[1.0]]]]),
                layer_idx=0,
                cache_kwargs={"cache_position": pos_out_of_bounds},
            )

    def test_static_cache(self):
        """Test StaticCache with manually prefilled states and hardcoded assertions.

        Scenario 1: Fill up to near capacity
        prefill: [1.0, 2.0, 0.0, 0.0]
        update pos 2: [1.0, 2.0, 3.0, 0.0]

        Scenario 2: Fill to capacity
        update pos 3: [1.0, 2.0, 3.0, 4.0]
        """
        # Scenario 1: Fill up to near capacity
        static_cache = StaticCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
        prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None]
        static_cache.update(key_states=prefill, value_states=prefill, layer_idx=0, cache_kwargs=None)
        static_cache.update(
            key_states=torch.tensor(3.0)[None, None, None, None],
            value_states=torch.tensor(3.0)[None, None, None, None],
            layer_idx=0,
            cache_kwargs={"cache_position": torch.tensor([2])},
        )
        self.assertEqual(
            static_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "StaticCache Scenario 1 failed"
        )

        # Scenario 2: Fill to capacity
        static_cache.update(
            key_states=torch.tensor(4.0)[None, None, None, None],
            value_states=torch.tensor(4.0)[None, None, None, None],
            layer_idx=0,
            cache_kwargs={"cache_position": torch.tensor([3])},
        )
        self.assertEqual(
            static_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "StaticCache Scenario 2 failed"
        )
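
        # StaticCache writes each token at its absolute `cache_position` and never evicts; the
        # sliding-window tests below instead keep only the last `window_size` positions once the
        # sequence grows past the window.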

    def test_sliding_window_cache(self):
        """Test SlidingWindowCache with manually prefilled states and hardcoded assertions.

        Scenario 1: Update within window, no slide yet
        prefill: [1.0, 2.0, 0.0, 0.0]
        update pos 2: [1.0, 2.0, 3.0, 0.0]

        Scenario 2: Update causing slide
        prefill: [1.0, 2.0, 3.0, 4.0]
        update pos 4: [2.0, 3.0, 4.0, 5.0] (shift happens as pos > window_size-1)

        Scenario 3: Long prompt handling (prompt_len > window_size)
        input: [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
        result: [3.0, 4.0, 5.0, 6.0] (keeps last window_size tokens)
        """
        # Scenario 1: Update within window, no slide yet
        sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
        prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None]
        sliding_cache.update(
            key_states=prefill,
            value_states=prefill,
            layer_idx=0,
            cache_kwargs={"cache_position": torch.arange(4), "sliding_window": self.window_size},
        )
        sliding_cache.update(
            key_states=torch.tensor(3.0)[None, None, None, None],
            value_states=torch.tensor(3.0)[None, None, None, None],
            layer_idx=0,
            cache_kwargs={"cache_position": torch.tensor([2]), "sliding_window": self.window_size},
        )
        self.assertEqual(
            sliding_cache.key_cache[0][0, 0, :, 0].tolist(),
            [1.0, 2.0, 3.0, 0.0],
            "SlidingWindowCache Scenario 1 failed",
        )

        # Scenario 2: Update causing slide
        sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
        prefill = torch.tensor([1.0, 2.0, 3.0, 4.0])[None, None, :, None]
        sliding_cache.update(
            key_states=prefill,
            value_states=prefill,
            layer_idx=0,
            cache_kwargs={"cache_position": torch.arange(4), "sliding_window": self.window_size},
        )
        sliding_cache.update(
            key_states=torch.tensor(5.0)[None, None, None, None],
            value_states=torch.tensor(5.0)[None, None, None, None],
            layer_idx=0,
            cache_kwargs={"cache_position": torch.tensor([4]), "sliding_window": self.window_size},
        )
        self.assertEqual(
            sliding_cache.key_cache[0][0, 0, :, 0].tolist(),
            [2.0, 3.0, 4.0, 5.0],
            "SlidingWindowCache Scenario 2 failed",
        )

        # Scenario 3: Long prompt handling
        sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
        long_prefill = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])[None, None, :, None]
        sliding_cache.update(
            key_states=long_prefill,
            value_states=long_prefill,
            layer_idx=0,
            cache_kwargs={"cache_position": torch.arange(6), "sliding_window": self.window_size},
        )
        self.assertEqual(
            sliding_cache.key_cache[0][0, 0, :, 0].tolist(),
            [3.0, 4.0, 5.0, 6.0],
            "SlidingWindowCache Scenario 3 failed",
        )

    def test_hybrid_cache_static_mode(self):
        """Test HybridCache in static mode with hardcoded assertions.

        Scenario 1: Static layer behavior
        prefill: [1.0, 2.0, 0.0, 0.0]
        update pos 2: [1.0, 2.0, 3.0, 0.0]

        Scenario 2: Fill to capacity
        update pos 3: [1.0, 2.0, 3.0, 4.0]
        """
        config = copy.deepcopy(self.config)
        config.sliding_window_pattern = 1  # Layer 0 is static (1 % 1 == 0)

        # Scenario 1
        hybrid_cache_static_mode = HybridCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len)
        prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None]
        hybrid_cache_static_mode.update(
            key_states=prefill,
            value_states=prefill,
            layer_idx=0,
            cache_kwargs={"cache_position": torch.arange(4)},
        )
        hybrid_cache_static_mode.update(
            key_states=torch.tensor(3.0)[None, None, None, None],
            value_states=torch.tensor(3.0)[None, None, None, None],
            layer_idx=0,
            cache_kwargs={"cache_position": torch.tensor([2])},
        )
        self.assertEqual(
            hybrid_cache_static_mode.key_cache[0][0, 0, :, 0].tolist(),
            [1.0, 2.0, 3.0, 0.0],
            "HybridCache Static Scenario 1 failed",
        )

        # Scenario 2
        hybrid_cache_static_mode.update(
            key_states=torch.tensor(4.0)[None, None, None, None],
            value_states=torch.tensor(4.0)[None, None, None, None],
            layer_idx=0,
            cache_kwargs={"cache_position": torch.tensor([3])},
        )
        self.assertEqual(
            hybrid_cache_static_mode.key_cache[0][0, 0, :, 0].tolist(),
            [1.0, 2.0, 3.0, 4.0],
            "HybridCache Static Scenario 2 failed",
        )

    def test_hybrid_cache_sliding_mode(self):
        """Test HybridCache in sliding mode with hardcoded assertions.

        Scenario 1: Update within window, no slide yet
        prefill: [1.0, 2.0, 0.0, 0.0]
        update pos 2: [1.0, 2.0, 3.0, 0.0]

        Scenario 2: Update causing first slide
        prefill: [1.0, 2.0, 3.0, 4.0]
        update pos 4: [2.0, 3.0, 4.0, 5.0] (shift happens as pos > window_size-1)

        Scenario 3: Update causing subsequent slide
        update pos 5: [3.0, 4.0, 5.0, 6.0] (shift continues)

        Scenario 4: Long prompt handling (prompt_len > window_size)
        input: [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
        result: [3.0, 4.0, 5.0, 6.0] (keeps last window_size tokens)
        """
        # Scenario 1: Update within window, no slide yet
        hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
        prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None]
        hybrid_cache.update(
            key_states=prefill,
            value_states=prefill,
            layer_idx=0,
            cache_kwargs={"cache_position": torch.arange(4), "sliding_window": self.window_size},
        )
        hybrid_cache.update(
            key_states=torch.tensor(3.0)[None, None, None, None],
            value_states=torch.tensor(3.0)[None, None, None, None],
            layer_idx=0,
            cache_kwargs={"cache_position": torch.tensor([2]), "sliding_window": self.window_size},
        )
        self.assertEqual(
            hybrid_cache.key_cache[0][0, 0, :, 0].tolist(),
            [1.0, 2.0, 3.0, 0.0],
            "HybridCache Sliding Scenario 1 failed",
        )

        # Scenario 2: Update causing first slide
        hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
        prefill = torch.tensor([1.0, 2.0, 3.0, 4.0])[None, None, :, None]
        hybrid_cache.update(
            key_states=prefill,
            value_states=prefill,
            layer_idx=0,
            cache_kwargs={"cache_position": torch.arange(4), "sliding_window": self.window_size},
        )
        hybrid_cache.update(
            key_states=torch.tensor(5.0)[None, None, None, None],
            value_states=torch.tensor(5.0)[None, None, None, None],
            layer_idx=0,
            cache_kwargs={"cache_position": torch.tensor([4]), "sliding_window": self.window_size},
        )
        self.assertEqual(
            hybrid_cache.key_cache[0][0, 0, :, 0].tolist(),
            [2.0, 3.0, 4.0, 5.0],
            "HybridCache Sliding Scenario 2 failed",
        )

        # Scenario 3: Update causing subsequent slide
        hybrid_cache.update(
            key_states=torch.tensor(6.0)[None, None, None, None],
            value_states=torch.tensor(6.0)[None, None, None, None],
            layer_idx=0,
            cache_kwargs={"cache_position": torch.tensor([5]), "sliding_window": self.window_size},
        )
        self.assertEqual(
            hybrid_cache.key_cache[0][0, 0, :, 0].tolist(),
            [3.0, 4.0, 5.0, 6.0],
            "HybridCache Sliding Scenario 3 failed",
        )

        # Scenario 4: Long prompt handling
        hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
        long_prefill = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])[None, None, :, None]
        hybrid_cache.update(
            key_states=long_prefill,
            value_states=long_prefill,
            layer_idx=0,
            cache_kwargs={"cache_position": torch.arange(6), "sliding_window": self.window_size},
        )
        self.assertEqual(
            hybrid_cache.key_cache[0][0, 0, :, 0].tolist(),
            [3.0, 4.0, 5.0, 6.0],
            "HybridCache Sliding Scenario 4 failed",
        )