mirror of https://github.com/huggingface/transformers.git
synced 2025-07-04 13:20:12 +06:00

New cache tests and refactored Hybrid Cache (#37972)

This commit is contained in:
parent 183fb3637c
commit d34e21e7dd
@@ -21,6 +21,104 @@ if is_hqq_available():

 logger = logging.get_logger(__name__)


+# Utility functions for static/sliding cache update logic
+def _static_cache_update(
+    k_cache: torch.Tensor,
+    v_cache: torch.Tensor,
+    key_states: torch.Tensor,
+    value_states: torch.Tensor,
+    cache_position: Optional[torch.LongTensor],
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Updates the static cache tensors in place.
+
+    Args:
+        k_cache (`torch.Tensor`): The key cache tensor to update.
+        v_cache (`torch.Tensor`): The value cache tensor to update.
+        key_states (`torch.Tensor`): The new key states to add.
+        value_states (`torch.Tensor`): The new value states to add.
+        cache_position (`Optional[torch.LongTensor]`): The position indices where the new states should be inserted.
+            If None, the entire cache is overwritten (prefill).
+
+    Returns:
+        Tuple[`torch.Tensor`, `torch.Tensor`]: The updated key and value cache tensors (modified in-place).
+    """
+    if cache_position is None:
+        # Prefill phase where seq_len potentially equals max_cache_len. Directly copy.
+        k_cache.copy_(key_states)
+        v_cache.copy_(value_states)
+    else:
+        # Generation phase. Update specific positions.
+        # Use index_copy_ for in-place update (compile-friendly).
+        try:
+            k_cache.index_copy_(2, cache_position, key_states)
+            v_cache.index_copy_(2, cache_position, value_states)
+        except NotImplementedError:
+            # Fallback for devices like MPS where index_copy_ might not be supported.
+            k_cache[:, :, cache_position] = key_states
+            v_cache[:, :, cache_position] = value_states
+    return k_cache, v_cache
+
+
+def _sliding_cache_update(
+    k_cache: torch.Tensor,
+    v_cache: torch.Tensor,
+    key_states: torch.Tensor,
+    value_states: torch.Tensor,
+    cache_position: torch.LongTensor,
+    max_cache_len: int,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Updates the sliding window cache tensors, returning the potentially modified tensors.
+
+    Args:
+        k_cache (`torch.Tensor`): The key cache tensor to update.
+        v_cache (`torch.Tensor`): The value cache tensor to update.
+        key_states (`torch.Tensor`): The new key states to add.
+        value_states (`torch.Tensor`): The new value states to add.
+        cache_position (`torch.LongTensor`): The position indices where the new states should be inserted.
+        max_cache_len (`int`): The maximum length of the sliding window cache.
+
+    Returns:
+        Tuple[`torch.Tensor`, `torch.Tensor`]: The key and value tensors representing the cache state after the update.
+            For prefill > window, these are the full input states.
+            Otherwise, they are the updated cache tensors.
+    """
+    # Handle prefill phase when prompt length > sliding_window_size
+    if cache_position.shape[0] > max_cache_len:
+        new_k = key_states[:, :, -max_cache_len:, :]
+        new_v = value_states[:, :, -max_cache_len:, :]
+        k_cache.copy_(new_k)
+        v_cache.copy_(new_v)
+        return key_states, value_states
+
+    # Sliding window logic for generation phase or prefill < window
+    slicing = torch.arange(max_cache_len, device=value_states.device)
+    current_seq_len = cache_position[-1] + 1  # Use last position to determine current length
+    to_shift = current_seq_len > max_cache_len
+    indices = (slicing + to_shift.sum()) % max_cache_len
+
+    k_out_shifted = k_cache[:, :, indices]
+    v_out_shifted = v_cache[:, :, indices]
+
+    # Clamp cache_position to determine the *target index* within the shifted cache view
+    update_position = cache_position.clamp(min=0, max=max_cache_len - 1)
+
+    try:
+        k_out_updated = k_out_shifted.index_copy(2, update_position, key_states)
+        v_out_updated = v_out_shifted.index_copy(2, update_position, value_states)
+    except NotImplementedError:
+        # Fallback for MPS: clone and modify the clone
+        k_out_updated = k_out_shifted.clone()
+        v_out_updated = v_out_shifted.clone()
+        k_out_updated[:, :, update_position] = key_states
+        v_out_updated[:, :, update_position] = value_states
+
+    k_cache.copy_(k_out_updated)
+    v_cache.copy_(v_out_updated)
+    return k_out_updated, v_out_updated
+
+
 class Cache:
     """
     Base, abstract class for all caches. The actual data structure is specific to each subclass.
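For illustration only (not part of the diff above): a minimal standalone sketch of the rolling-index arithmetic that `_sliding_cache_update` performs once the window is full, using plain PyTorch and toy 1-D values. The helper applies the same logic to 4-D (batch, heads, seq, head_dim) cache tensors and then copies the result back into the preallocated buffer.

    import torch

    max_cache_len = 4
    cache = torch.tensor([1.0, 2.0, 3.0, 4.0])            # a full window holding positions 0..3
    cache_position = torch.tensor([4])                     # the next token lands past the window

    slicing = torch.arange(max_cache_len)
    to_shift = (cache_position[-1] + 1) > max_cache_len    # True once the sequence outgrows the window
    indices = (slicing + to_shift.sum()) % max_cache_len   # -> tensor([1, 2, 3, 0]): roll left by one

    shifted = cache[indices]                               # oldest entry moves to the last slot
    update_position = cache_position.clamp(max=max_cache_len - 1)
    shifted[update_position] = 5.0                         # write the new token into that slot
    print(shifted)                                         # tensor([2., 3., 4., 5.])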
@@ -1264,28 +1362,16 @@ class StaticCache(Cache):
         """
         if cache_kwargs is None:
             cache_kwargs = {}
-        cache_position = cache_kwargs.get("cache_position")
-        k_out = self.key_cache[layer_idx]
-        v_out = self.value_cache[layer_idx]
-        key_states = key_states.to(k_out.dtype)
-        value_states = value_states.to(v_out.dtype)
-
-        if cache_position is None:
-            k_out.copy_(key_states)
-            v_out.copy_(value_states)
-        else:
-            # Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to
-            # `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does explicitly an in-place
-            # operation, that avoids copies and uses less memory.
-            try:
-                k_out.index_copy_(2, cache_position, key_states)
-                v_out.index_copy_(2, cache_position, value_states)
-            except NotImplementedError:
-                # The operator 'aten::index_copy.out' is not currently implemented for the MPS device.
-                k_out[:, :, cache_position] = key_states
-                v_out[:, :, cache_position] = value_states
-
-        return k_out, v_out
+        key_states = key_states.to(self.key_cache[layer_idx].dtype)
+        value_states = value_states.to(self.value_cache[layer_idx].dtype)
+        return _static_cache_update(
+            self.key_cache[layer_idx],
+            self.value_cache[layer_idx],
+            key_states,
+            value_states,
+            cache_kwargs.get("cache_position"),
+        )

     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         """Returns the sequence length of the cached states that were seen by the model."""
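Illustrative aside (not from the commit): the helper keeps the old code's preference for `index_copy_` over slice assignment because it is an explicit in-place write that torch.compile can trace without graph breaks. A toy sketch of the two equivalent writes:

    import torch

    k_cache = torch.zeros(1, 1, 4, 2)                  # (batch, kv_heads, max_cache_len, head_dim)
    new_key = torch.ones(1, 1, 1, 2)                   # one freshly decoded token
    cache_position = torch.tensor([2])                 # slot to write

    k_cache.index_copy_(2, cache_position, new_key)    # in-place along the sequence dimension
    # same result via slice assignment (the MPS fallback): k_cache[:, :, cache_position] = new_key
    print(k_cache[0, 0, :, 0])                         # tensor([0., 0., 1., 0.])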
@@ -1314,7 +1400,7 @@ class SlidingWindowCache(StaticCache):

     The `to_shift` is only true once we are above sliding_window. Thus with `sliding_window==64`:

-    indices = (slicing + to_shift[-1].int()-1) % self.config.sliding_window
+    indices = (slicing + to_shift[-1].sum()-1) % self.config.sliding_window
     tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
             19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36,
             37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
@@ -1398,46 +1484,21 @@ class SlidingWindowCache(StaticCache):
         if cache_kwargs is None:
             cache_kwargs = {}
         cache_position = cache_kwargs.get("cache_position")
-        k_out = self.key_cache[layer_idx]
-        v_out = self.value_cache[layer_idx]
-        key_states = key_states.to(k_out.dtype)
-        value_states = value_states.to(v_out.dtype)
-
-        # assume this only happens in prefill phase when prompt length > sliding_window_size (= max_cache_len)
-        if cache_position.shape[0] >= self.max_cache_len:
-            k_out = key_states[:, :, -self.max_cache_len :, :]
-            v_out = value_states[:, :, -self.max_cache_len :, :]
-            # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
-            self.key_cache[layer_idx] += k_out
-            self.value_cache[layer_idx] += v_out
-            # we should return the whole states instead of k_out, v_out to take the whole prompt
-            # into consideration when building kv cache instead of just throwing away tokens outside of the window
-            return key_states, value_states
-
-        slicing = torch.ones(self.max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
-        to_shift = cache_position > self.max_cache_len - 1
-        cache_position = cache_position.clamp(0, self.max_cache_len - 1)
-        indices = (slicing + to_shift[-1].int() - 1) % self.max_cache_len
-
-        k_out = k_out[:, :, indices]
-        v_out = v_out[:, :, indices]
-
-        try:
-            k_out.index_copy_(2, cache_position, key_states)
-            v_out.index_copy_(2, cache_position, value_states)
-        except NotImplementedError:
-            # The operator 'aten::index_copy.out' is not currently implemented for the MPS device.
-            k_out[:, :, cache_position] = key_states
-            v_out[:, :, cache_position] = value_states
-
-        # `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment)
-        self.key_cache[layer_idx].zero_()
-        self.value_cache[layer_idx].zero_()
-
-        self.key_cache[layer_idx] += k_out
-        self.value_cache[layer_idx] += v_out
-
-        return k_out, v_out
+        if cache_position is None:
+            raise ValueError("`cache_position` must be provided for SlidingWindowCache.")
+
+        key_states = key_states.to(self.key_cache[layer_idx].dtype)
+        value_states = value_states.to(self.value_cache[layer_idx].dtype)
+
+        return _sliding_cache_update(
+            self.key_cache[layer_idx],
+            self.value_cache[layer_idx],
+            key_states,
+            value_states,
+            cache_position,
+            self.max_cache_len,
+        )

     def get_max_cache_shape(self) -> Optional[int]:
         return self.max_cache_len
@@ -1680,12 +1741,13 @@ class HybridCache(Cache):
         super().__init__()
         if not hasattr(config, "sliding_window") or config.sliding_window is None:
             raise ValueError(
-                "Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
+                "Setting `cache_implementation` to 'hybrid' requires the model config supporting "
                 "sliding window attention, please check if there is a `sliding_window` field in the model "
                 "config and it's not set to None."
             )
-        self.max_cache_len = max_cache_len
-        self._sliding_window_max_len = min(config.sliding_window, max_cache_len)
+        self.max_cache_len = max_cache_len if max_cache_len is not None else config.max_position_embeddings
+        # Sliding layers can't be larger than the overall max cache len
+        self.sliding_window_len = min(config.sliding_window, self.max_cache_len)
         self.max_batch_size = max_batch_size
         # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
         self.head_dim = (
@@ -1694,22 +1756,17 @@ class HybridCache(Cache):

         self._dtype = dtype
         self.num_key_value_heads = (
-            config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
+            config.num_attention_heads
+            if getattr(config, "num_key_value_heads", None) is None
+            else config.num_key_value_heads
         )

         layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2  # 2 is for BC
-        self.is_sliding = torch.tensor(
-            [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], dtype=torch.bool
-        )
+        self.is_sliding_list = [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)]
         self.key_cache: List[torch.Tensor] = []
         self.value_cache: List[torch.Tensor] = []
-        global_cache_shape = (self.max_batch_size, self.num_key_value_heads, max_cache_len, self.head_dim)
-        sliding_cache_shape = (
-            self.max_batch_size,
-            self.num_key_value_heads,
-            self._sliding_window_max_len,
-            self.head_dim,
-        )
+        global_cache_shape = (self.max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
+        sliding_cache_shape = (self.max_batch_size, self.num_key_value_heads, self.sliding_window_len, self.head_dim)
         device = torch.device(device) if device is not None else None
         for i in range(config.num_hidden_layers):
             if layer_device_map is not None:
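A quick illustrative check (toy numbers, not from the commit) of how `layer_switch` maps onto `is_sliding_list`: with the backward-compatible pattern of 2, every second layer becomes a global (static) layer and the rest use the sliding window, because `(i + 1) % layer_switch` is falsy exactly on the pattern boundary.

    layer_switch = 2
    num_hidden_layers = 6
    is_sliding_list = [bool((i + 1) % layer_switch) for i in range(num_hidden_layers)]
    print(is_sliding_list)  # [True, False, True, False, True, False] -> layers 1, 3, 5 are global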
@@ -1718,7 +1775,7 @@ class HybridCache(Cache):
                 layer_device = device
             # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
             # breaks when updating the cache.
-            cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape
+            cache_shape = sliding_cache_shape if self.is_sliding_list[i] else global_cache_shape
             new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
             new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
             torch._dynamo.mark_static_address(new_layer_key_cache)
@@ -1726,42 +1783,6 @@ class HybridCache(Cache):
             self.key_cache.append(new_layer_key_cache)
             self.value_cache.append(new_layer_value_cache)

-    def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
-        if cache_position.shape[0] >= max_cache_len:
-            k_out = key_states[:, :, -max_cache_len:, :]
-            v_out = value_states[:, :, -max_cache_len:, :]
-            # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
-            self.key_cache[layer_idx] += k_out
-            self.value_cache[layer_idx] += v_out
-            # we should return the whole states instead of k_out, v_out to take the whole prompt
-            # into consideration when building kv cache instead of just throwing away tokens outside of the window
-            return key_states, value_states
-
-        slicing = torch.ones(max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
-        to_shift = cache_position > max_cache_len - 1
-        cache_position = cache_position.clamp(0, max_cache_len - 1)
-        indices = (slicing + to_shift[-1].int() - 1) % max_cache_len
-        k_out = k_out[:, :, indices]
-        v_out = v_out[:, :, indices]
-
-        k_out[:, :, cache_position] = key_states
-        v_out[:, :, cache_position] = value_states
-        # `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment)
-        self.key_cache[layer_idx].zero_()
-        self.value_cache[layer_idx].zero_()
-
-        self.key_cache[layer_idx] += k_out
-        self.value_cache[layer_idx] += v_out
-        return k_out, v_out
-
-    def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
-        k_out[:, :, cache_position] = key_states
-        v_out[:, :, cache_position] = value_states
-
-        self.key_cache[layer_idx] = k_out
-        self.value_cache[layer_idx] = v_out
-        return k_out, v_out
-
     def update(
         self,
         key_states: torch.Tensor,
@@ -1772,7 +1793,10 @@ class HybridCache(Cache):
         if cache_kwargs is None:
             cache_kwargs = {}
         cache_position = cache_kwargs.get("cache_position")
-        sliding_window = cache_kwargs.get("sliding_window")
+        if cache_position is None:
+            raise ValueError("`cache_position` must be provided for HybridCache.")
+
+        is_sliding_layer = self.is_sliding_list[layer_idx]

         # These two `if` blocks are only reached in multigpu and if `layer_device_map` is not passed. They are used
         # when the cache is initialized in the forward pass (e.g. Gemma2)
@@ -1781,25 +1805,22 @@ class HybridCache(Cache):
         if self.value_cache[layer_idx].device != value_states.device:
             self.value_cache[layer_idx] = self.value_cache[layer_idx].to(value_states.device)

-        k_out = self.key_cache[layer_idx]
-        v_out = self.value_cache[layer_idx]
-        key_states = key_states.to(k_out.dtype)
-        value_states = value_states.to(v_out.dtype)
+        k_cache = self.key_cache[layer_idx]
+        v_cache = self.value_cache[layer_idx]
+        key_states = key_states.to(k_cache.dtype)
+        value_states = value_states.to(v_cache.dtype)

-        if sliding_window:
-            update_fn = self._sliding_update
+        if is_sliding_layer:
+            return _sliding_cache_update(
+                k_cache,
+                v_cache,
+                key_states,
+                value_states,
+                cache_position,
+                k_cache.shape[2],  # Use actual cache dim as max cache len
+            )
         else:
-            update_fn = self._static_update
-
-        return update_fn(
-            cache_position,
-            layer_idx,
-            key_states,
-            value_states,
-            k_out,
-            v_out,
-            k_out.shape[2],
-        )
+            return _static_cache_update(k_cache, v_cache, key_states, value_states, cache_position)

     def get_max_cache_shape(self) -> Optional[int]:
         return self.max_cache_len
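A hedged usage sketch of the refactored `HybridCache.update`, mirroring the synthetic tests added further down (exact constructor kwargs may differ across transformers versions): every call must now pass `cache_position`, and the cache dispatches to the sliding or static helper based on the layer index rather than on a `sliding_window` entry in `cache_kwargs`.

    import torch
    from transformers import Gemma2Config, HybridCache

    config = Gemma2Config(
        num_hidden_layers=2, num_attention_heads=1, num_key_value_heads=1,
        head_dim=1, hidden_size=1, sliding_window=4, sliding_window_pattern=2,
    )
    cache = HybridCache(config=config, max_batch_size=1, max_cache_len=8)

    key = value = torch.zeros(1, 1, 1, 1)  # (batch, kv_heads, seq_len, head_dim)
    # Layer 0 is a sliding layer under this pattern, so this routes to _sliding_cache_update.
    k, v = cache.update(key, value, layer_idx=0, cache_kwargs={"cache_position": torch.tensor([0])})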
@@ -2033,7 +2054,7 @@ class OffloadedHybridCache(HybridChunkedCache):

         # TODO (joao): to enable this cache on multiple devicesuse the pattern from `OffloadedCache`, which keeps
         # track of the original device of each layer
-        unique_devices = set(layer_device_map.values())
+        unique_devices = set(layer_device_map.values()) if layer_device_map else set()
         if len(unique_devices) > 1:
             raise ValueError(f"OffloadedHybridCache does not support multiple devices. Got devices: {unique_devices}")

@@ -2292,7 +2313,7 @@ class OffloadedStaticCache(StaticCache):

         # TODO (joao): to enable this cache on multiple devicesuse the pattern from `OffloadedCache`, which keeps
         # track of the original device of each layer
-        unique_devices = set(layer_device_map.values())
+        unique_devices = set(layer_device_map.values()) if layer_device_map else set()
         if len(unique_devices) > 1:
             raise ValueError(f"OffloadedStaticCache does not support multiple devices. Got devices: {unique_devices}")

@@ -2369,6 +2390,9 @@ class OffloadedStaticCache(StaticCache):
             A tuple containing the updated key and value states.
         """

+        key_states = key_states.to(self.key_cache[layer_idx].dtype)
+        value_states = value_states.to(self.value_cache[layer_idx].dtype)
+
         if layer_idx == 0:
             # Update seen tokens.
             # TODO(gante): Remove this.
@@ -46,10 +46,14 @@ if is_torch_available():
         Cache,
         ClvpForCausalLM,
         DynamicCache,
+        Gemma2Config,
         GenerationConfig,
+        HybridCache,
         LlamaConfig,
+        SlidingWindowCache,
         StaticCache,
         convert_and_export_with_cache,
+        pipeline,
     )


@@ -188,6 +192,21 @@ class CacheTest(unittest.TestCase):
         self.assertTrue(cached_values.shape == (1, 1, 10, 128))


+def _skip_on_failed_cache_prerequisites(test, cache_implementation):
+    """Function to skip tests on failed cache prerequisites, given a cache implementation"""
+    # Installed dependencies
+    if cache_implementation == "quantized" and not is_optimum_quanto_available():
+        test.skipTest("Quanto is not available")
+    # Devices
+    if "offloaded" in cache_implementation:
+        has_accelerator = torch_device is not None and torch_device != "cpu"
+        if not has_accelerator:
+            test.skipTest("Offloaded caches require an accelerator")
+    if cache_implementation in ["offloaded_static", "offloaded_hybrid_chunked"]:
+        if backend_device_count(torch_device) != 1:
+            test.skipTest("Offloaded static caches require exactly 1 accelerator")
+
+
 class CacheIntegrationTest(unittest.TestCase):
     """Fast cache integration tests that share the same small model"""

@@ -200,24 +219,10 @@ class CacheIntegrationTest(unittest.TestCase):
         )
         cls.model.config.sliding_window = 256  # hack to enable the use of caches with sliding windows

-    def _skip_on_failed_cache_prerequisites(self, cache_implementation):
-        """Function to skip tests on failed cache prerequisites, given a cache implementation"""
-        # Installed dependencies
-        if cache_implementation == "quantized" and not is_optimum_quanto_available():
-            self.skipTest("Quanto is not available")
-        # Devices
-        if "offloaded" in cache_implementation:
-            has_accelerator = torch_device is not None and torch_device != "cpu"
-            if not has_accelerator:
-                self.skipTest("Offloaded caches require an accelerator")
-        if cache_implementation in ["offloaded_static", "offloaded_hybrid_chunked"]:
-            if backend_device_count(torch_device) != 1:
-                self.skipTest("Offloaded static caches require exactly 1 accelerator")
-
     @parameterized.expand(TEST_CACHE_IMPLEMENTATIONS)
     def test_cache_batched(self, cache_implementation):
         """Sanity check: caches' `.update` function expects batched inputs"""
-        self._skip_on_failed_cache_prerequisites(cache_implementation)
+        _skip_on_failed_cache_prerequisites(self, cache_implementation)

         EXPECTED_GENERATION = ["A sequence: 1, 2, 3, 4, 5, 6, 7, 8,", "A sequence: A, B, C, D, E, F, G, H"]

@@ -246,7 +251,7 @@ class CacheIntegrationTest(unittest.TestCase):
         Sanity check: caches' `reorder_cache` is operational. We can confirm this by looking at the beam indices
         (an output sequence contains multiple beam indices).
         """
-        self._skip_on_failed_cache_prerequisites(cache_implementation)
+        _skip_on_failed_cache_prerequisites(self, cache_implementation)
         if cache_implementation == "offloaded_hybrid_chunked":
             # TODO (joao, cyril): something is off with `offloaded_hybrid_chunked` aka `OffloadedHybridCache`: the
             # output sequence (and the corresponding beam scores, if we add `output_scores=True`) are significantly
@@ -280,7 +285,7 @@ class CacheIntegrationTest(unittest.TestCase):
     @parameterized.expand(TEST_CACHE_IMPLEMENTATIONS)
     def test_cache_extra_left_padding(self, cache_implementation):
         """Tests that adding extra left-padding does not affect the generation with the cache"""
-        self._skip_on_failed_cache_prerequisites(cache_implementation)
+        _skip_on_failed_cache_prerequisites(self, cache_implementation)

         EXPECTED_GENERATION = ["The cat's whiskers are also a sign of anxiety."]

@@ -552,6 +557,28 @@ class CacheHardIntegrationTest(unittest.TestCase):
         _ = model(**inputs)
         _ = model.generate(**inputs, max_new_tokens=2, cache_implementation="hybrid")

+    @require_torch_gpu
+    @parameterized.expand(TEST_CACHE_IMPLEMENTATIONS)
+    def test_cache_gptj_model(self, cache_implementation):
+        """Tests caches with GPT-J model. Regression test for https://github.com/huggingface/transformers/pull/34799"""
+        _skip_on_failed_cache_prerequisites(self, cache_implementation)
+
+        model_id = "hf-internal-testing/tiny-random-GPTJForCausalLM"
+        pipe = pipeline("text-generation", model=model_id, torch_dtype=torch.bfloat16)
+        pipe.model.config.sliding_window = (
+            256 if cache_implementation in ["sliding_window", "hybrid", "hybrid_chunked"] else None
+        )
+        out = pipe(
+            "hello world",
+            cache_implementation=cache_implementation,
+            max_new_tokens=10,
+            do_sample=False,
+            disable_compile=True,
+            return_tensors=True,
+        )[0]["generated_token_ids"][-10:]
+        EXPECTED_OUTPUT = [879, 175, 39, 141, 1000, 975, 951, 991, 683, 441]
+        self.assertListEqual(out, EXPECTED_OUTPUT)
+

 @require_torch
 class CacheExportIntegrationTest(unittest.TestCase):
@@ -721,3 +748,276 @@ class CacheExportIntegrationTest(unittest.TestCase):
             dynamic_shapes=dynamic_shapes,
             strict=False,
         )
+
+
+class SyntheticCacheTest(unittest.TestCase):
+    """Tests cache behavior with simple dummy data."""
+
+    def setUp(self):
+        """Set up common configuration and cache instances for all tests."""
+        self.window_size = 4
+        self.max_cache_len = 4
+        self.config = Gemma2Config(
+            num_hidden_layers=1,
+            num_key_value_heads=1,
+            num_attention_heads=1,
+            head_dim=1,
+            hidden_size=1,
+            sliding_window=self.window_size,
+            sliding_window_pattern=2,  # Default pattern for hybrid sliding
+        )
+
+    def test_static_cache_out_of_bounds(self):
+        """Test StaticCache raises IndexError for out-of-bounds positions."""
+        static_cache = StaticCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
+        pos_out_of_bounds = torch.tensor([self.max_cache_len])  # Position >= max_cache_len
+
+        with self.assertRaises(IndexError):
+            static_cache.update(
+                key_states=torch.tensor([[[[1.0]]]]),
+                value_states=torch.tensor([[[[1.0]]]]),
+                layer_idx=0,
+                cache_kwargs={"cache_position": pos_out_of_bounds},
+            )
+
+    def test_static_cache(self):
+        """Test StaticCache with manually prefilled states and hardcoded assertions.
+
+        Scenario 1: Fill up to near capacity
+            prefill:      [1.0, 2.0, 0.0, 0.0]
+            update pos 2: [1.0, 2.0, 3.0, 0.0]
+
+        Scenario 2: Fill to capacity
+            update pos 3: [1.0, 2.0, 3.0, 4.0]
+        """
+        # Scenario 1: Fill up to near capacity
+        static_cache = StaticCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
+        prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None]
+        static_cache.update(key_states=prefill, value_states=prefill, layer_idx=0, cache_kwargs=None)
+        static_cache.update(
+            key_states=torch.tensor(3.0)[None, None, None, None],
+            value_states=torch.tensor(3.0)[None, None, None, None],
+            layer_idx=0,
+            cache_kwargs={"cache_position": torch.tensor([2])},
+        )
+        self.assertEqual(
+            static_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "StaticCache Scenario 1 failed"
+        )
+
+        # Scenario 2: Fill to capacity
+        static_cache.update(
+            key_states=torch.tensor(4.0)[None, None, None, None],
+            value_states=torch.tensor(4.0)[None, None, None, None],
+            layer_idx=0,
+            cache_kwargs={"cache_position": torch.tensor([3])},
+        )
+        self.assertEqual(
+            static_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "StaticCache Scenario 2 failed"
+        )
+
+    def test_sliding_window_cache(self):
+        """Test SlidingWindowCache with manually prefilled states and hardcoded assertions.
+
+        Scenario 1: Update within window, no slide yet
+            prefill:      [1.0, 2.0, 0.0, 0.0]
+            update pos 2: [1.0, 2.0, 3.0, 0.0]
+
+        Scenario 2: Update causing slide
+            prefill:      [1.0, 2.0, 3.0, 4.0]
+            update pos 4: [2.0, 3.0, 4.0, 5.0] (shift happens as pos > window_size-1)
+
+        Scenario 3: Long prompt handling (prompt_len > window_size)
+            input:  [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
+            result: [3.0, 4.0, 5.0, 6.0] (keeps last window_size tokens)
+        """
+        # Scenario 1: Update within window, no slide yet
+        sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
+        prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None]
+        sliding_cache.update(
+            key_states=prefill,
+            value_states=prefill,
+            layer_idx=0,
+            cache_kwargs={"cache_position": torch.arange(4), "sliding_window": self.window_size},
+        )
+        sliding_cache.update(
+            key_states=torch.tensor(3.0)[None, None, None, None],
+            value_states=torch.tensor(3.0)[None, None, None, None],
+            layer_idx=0,
+            cache_kwargs={"cache_position": torch.tensor([2]), "sliding_window": self.window_size},
+        )
+        self.assertEqual(
+            sliding_cache.key_cache[0][0, 0, :, 0].tolist(),
+            [1.0, 2.0, 3.0, 0.0],
+            "SlidingWindowCache Scenario 1 failed",
+        )
+
+        # Scenario 2: Update causing slide
+        sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
+        prefill = torch.tensor([1.0, 2.0, 3.0, 4.0])[None, None, :, None]
+        sliding_cache.update(
+            key_states=prefill,
+            value_states=prefill,
+            layer_idx=0,
+            cache_kwargs={"cache_position": torch.arange(4), "sliding_window": self.window_size},
+        )
+        sliding_cache.update(
+            key_states=torch.tensor(5.0)[None, None, None, None],
+            value_states=torch.tensor(5.0)[None, None, None, None],
+            layer_idx=0,
+            cache_kwargs={"cache_position": torch.tensor([4]), "sliding_window": self.window_size},
+        )
+        self.assertEqual(
+            sliding_cache.key_cache[0][0, 0, :, 0].tolist(),
+            [2.0, 3.0, 4.0, 5.0],
+            "SlidingWindowCache Scenario 2 failed",
+        )
+
+        # Scenario 3: Long prompt handling
+        sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
+        long_prefill = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])[None, None, :, None]
+        sliding_cache.update(
+            key_states=long_prefill,
+            value_states=long_prefill,
+            layer_idx=0,
+            cache_kwargs={"cache_position": torch.arange(6), "sliding_window": self.window_size},
+        )
+        self.assertEqual(
+            sliding_cache.key_cache[0][0, 0, :, 0].tolist(),
+            [3.0, 4.0, 5.0, 6.0],
+            "SlidingWindowCache Scenario 3 failed",
+        )
+
+    def test_hybrid_cache_static_mode(self):
+        """Test HybridCache in static mode with hardcoded assertions.
+
+        Scenario 1: Static layer behavior
+            prefill:      [1.0, 2.0, 0.0, 0.0]
+            update pos 2: [1.0, 2.0, 3.0, 0.0]
+
+        Scenario 2: Fill to capacity
+            update pos 3: [1.0, 2.0, 3.0, 4.0]
+        """
+        config = copy.deepcopy(self.config)
+        config.sliding_window_pattern = 1  # Layer 0 is static (1 % 1 == 0)
+
+        # Scenario 1
+        hybrid_cache_static_mode = HybridCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len)
+        prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None]
+        hybrid_cache_static_mode.update(
+            key_states=prefill,
+            value_states=prefill,
+            layer_idx=0,
+            cache_kwargs={"cache_position": torch.arange(4)},
+        )
+        hybrid_cache_static_mode.update(
+            key_states=torch.tensor(3.0)[None, None, None, None],
+            value_states=torch.tensor(3.0)[None, None, None, None],
+            layer_idx=0,
+            cache_kwargs={"cache_position": torch.tensor([2])},
+        )
+        self.assertEqual(
+            hybrid_cache_static_mode.key_cache[0][0, 0, :, 0].tolist(),
+            [1.0, 2.0, 3.0, 0.0],
+            "HybridCache Static Scenario 1 failed",
+        )
+
+        # Scenario 2
+        hybrid_cache_static_mode.update(
+            key_states=torch.tensor(4.0)[None, None, None, None],
+            value_states=torch.tensor(4.0)[None, None, None, None],
+            layer_idx=0,
+            cache_kwargs={"cache_position": torch.tensor([3])},
+        )
+        self.assertEqual(
+            hybrid_cache_static_mode.key_cache[0][0, 0, :, 0].tolist(),
+            [1.0, 2.0, 3.0, 4.0],
+            "HybridCache Static Scenario 2 failed",
+        )
+
+    def test_hybrid_cache_sliding_mode(self):
+        """Test HybridCache in sliding mode with hardcoded assertions.
+
+        Scenario 1: Update within window, no slide yet
+            prefill:      [1.0, 2.0, 0.0, 0.0]
+            update pos 2: [1.0, 2.0, 3.0, 0.0]
+
+        Scenario 2: Update causing first slide
+            prefill:      [1.0, 2.0, 3.0, 4.0]
+            update pos 4: [2.0, 3.0, 4.0, 5.0] (shift happens as pos > window_size-1)
+
+        Scenario 3: Update causing subsequent slide
+            update pos 5: [3.0, 4.0, 5.0, 6.0] (shift continues)
+
+        Scenario 4: Long prompt handling (prompt_len > window_size)
+            input:  [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
+            result: [3.0, 4.0, 5.0, 6.0] (keeps last window_size tokens)
+        """
+        # Scenario 1: Update within window, no slide yet
+        hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
+        prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None]
+        hybrid_cache.update(
+            key_states=prefill,
+            value_states=prefill,
+            layer_idx=0,
+            cache_kwargs={"cache_position": torch.arange(4), "sliding_window": self.window_size},
+        )
+        hybrid_cache.update(
+            key_states=torch.tensor(3.0)[None, None, None, None],
+            value_states=torch.tensor(3.0)[None, None, None, None],
+            layer_idx=0,
+            cache_kwargs={"cache_position": torch.tensor([2]), "sliding_window": self.window_size},
+        )
+        self.assertEqual(
+            hybrid_cache.key_cache[0][0, 0, :, 0].tolist(),
+            [1.0, 2.0, 3.0, 0.0],
+            "HybridCache Sliding Scenario 1 failed",
+        )
+
+        # Scenario 2: Update causing first slide
+        hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
+        prefill = torch.tensor([1.0, 2.0, 3.0, 4.0])[None, None, :, None]
+        hybrid_cache.update(
+            key_states=prefill,
+            value_states=prefill,
+            layer_idx=0,
+            cache_kwargs={"cache_position": torch.arange(4), "sliding_window": self.window_size},
+        )
+        hybrid_cache.update(
+            key_states=torch.tensor(5.0)[None, None, None, None],
+            value_states=torch.tensor(5.0)[None, None, None, None],
+            layer_idx=0,
+            cache_kwargs={"cache_position": torch.tensor([4]), "sliding_window": self.window_size},
+        )
+        self.assertEqual(
+            hybrid_cache.key_cache[0][0, 0, :, 0].tolist(),
+            [2.0, 3.0, 4.0, 5.0],
+            "HybridCache Sliding Scenario 2 failed",
+        )
+
+        # Scenario 3: Update causing subsequent slide
+        hybrid_cache.update(
+            key_states=torch.tensor(6.0)[None, None, None, None],
+            value_states=torch.tensor(6.0)[None, None, None, None],
+            layer_idx=0,
+            cache_kwargs={"cache_position": torch.tensor([5]), "sliding_window": self.window_size},
+        )
+        self.assertEqual(
+            hybrid_cache.key_cache[0][0, 0, :, 0].tolist(),
+            [3.0, 4.0, 5.0, 6.0],
+            "HybridCache Sliding Scenario 3 failed",
+        )
+
+        # Scenario 4: Long prompt handling
+        hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
+        long_prefill = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])[None, None, :, None]
+        hybrid_cache.update(
+            key_states=long_prefill,
+            value_states=long_prefill,
+            layer_idx=0,
+            cache_kwargs={"cache_position": torch.arange(6), "sliding_window": self.window_size},
+        )
+        self.assertEqual(
+            hybrid_cache.key_cache[0][0, 0, :, 0].tolist(),
+            [3.0, 4.0, 5.0, 6.0],
+            "HybridCache Sliding Scenario 4 failed",
+        )
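To run the new tests locally, something like the following should work; the module path is an assumption based on the usual repository layout and is not stated in this diff:

    python -m pytest tests/utils/test_cache_utils.py -k "SyntheticCacheTest or CacheIntegrationTest" -v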