raushan review, arthur review

This commit is contained in:
Manuel de Prada Corral 2025-07-02 17:15:16 +02:00
parent 26c28af616
commit 16a6624087
48 changed files with 533 additions and 649 deletions

View File

@ -82,24 +82,18 @@ When you use Transformers' [`Cache`] class, the self-attention module performs s
## Cache storage implementation
The actual storage of key-value pairs varies between cache implementations. As an example, consider the [`DynamicCache`].
Caches are structured as a list of layers, where each layer contains a key and value cache. The key and value caches are tensors with the shape `[batch_size, num_heads, seq_len, head_dim]`.
Layers can be of different types (e.g. `DynamicLayer`, `StaticLayer`, `SlidingWindowLayer`), which mostly changes how sequence length is handled and how the cache is updated.
In [`DynamicCache`], the key-value pairs are stored as two lists of tensors. Each tensor in the lists have the shape `[batch_size, num_heads, seq_len, head_dim]`.
- `key_cache`: A list of tensors, one for each layer.
- `value_cache`: A list of tensors, one for each layer.
The simplest is a `DynamicLayer` that grows as more tokens are processed. The sequence length dimension (`seq_len`) increases with each new token:
When new tokens are processed:
1. For each layer, the new key and value states are concatenated with the existing cache.
```py
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
cache.layers[idx].keys = torch.cat([cache.layers[idx].keys, key_states], dim=-2)
cache.layers[idx].values = torch.cat([cache.layers[idx].values, value_states], dim=-2)
```
2. The cache grows dynamically as more tokens are processed. The sequence length dimension (`seq_len`) increases with each new token.
3. The cache maintains a count of seen tokens through `self._seen_tokens`. This is updated when the first layer processes a new token.
Other layers like `StaticLayer` and `SlidingWindowLayer` have a fixed sequence length that is set when the cache is created. This makes them compatible with `torch.compile`. In the case of `SlidingWindowLayer`, existing tokens are shifted out of the cache when a new token is added.
The example below demonstrates how to create a generation loop with [`DynamicCache`]. As discussed, the attention mask is a concatenation of past and current token values and `1` is added to the cache position for the next token.
@ -143,7 +137,7 @@ The legacy format is essentially the same data structure but organized different
- The tensors have the same shape `[batch_size, num_heads, seq_len, head_dim]`.
- The format is less flexible and doesn't support features like quantization or offloading.
If your project depends on this legacy format, you can convert between [`DynamicCache`] and a tuple of tuples as shown below with the [`~DynamicCache.from_legacy_cache`] and [`DynamicCache.to_legacy_cache`] functions. This is helpful if you have custom logic for manipulating a cache in a specific format.
If your project depends on this legacy format, you can convert between [`DynamicCache`] and a tuple of tuples as shown below with the [`~Cache.from_legacy_cache`] and [`Cache.to_legacy_cache`] functions. This is helpful if you have custom logic for manipulating a cache in a specific format.
```py
import torch

File diff suppressed because it is too large Load Diff

View File

@ -1951,7 +1951,7 @@ class GenerationMixin(ContinuousMixin):
layer_device_map = self._get_layer_device_map_for_cache_init()
cache_kwargs = {
"config": self.config.get_text_config(),
"model_config": self.config.get_text_config(),
"max_batch_size": batch_size,
"max_cache_len": max_cache_len,
"dtype": cache_dtype,

View File

@ -275,15 +275,15 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
self.model = model
self.static_cache = StaticCache(
config=self.model.config,
model_config=self.model.config,
max_batch_size=self.model.generation_config.cache_config.batch_size,
max_cache_len=self.model.generation_config.cache_config.max_cache_len,
device=self.model.generation_config.cache_config.device,
dtype=self.model.dtype,
)
for i in range(len(self.static_cache.key_cache)):
self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i], persistent=False)
self.register_buffer(f"value_cache_{i}", self.static_cache.value_cache[i], persistent=False)
for i in range(len(self.static_cache)):
self.register_buffer(f"key_cache_{i}", self.static_cache.layers[i].keys, persistent=False)
self.register_buffer(f"value_cache_{i}", self.static_cache.layers[i].values, persistent=False)
def forward(self, input_ids: torch.Tensor, cache_position: torch.Tensor):
"""
@ -404,7 +404,7 @@ class TorchExportableModuleWithHybridCache(torch.nn.Module):
# Initialize the HybridCache
self.cache = HybridCache(
config=self.model.config,
model_config=self.model.config,
max_batch_size=max_batch_size,
max_cache_len=max_cache_len,
device=self.model.device,
@ -412,9 +412,9 @@ class TorchExportableModuleWithHybridCache(torch.nn.Module):
)
# Register all key and value cache tensors as buffers
for i in range(len(self.cache.key_cache)):
self.register_buffer(f"key_cache_{i}", self.cache.key_cache[i], persistent=False)
self.register_buffer(f"value_cache_{i}", self.cache.value_cache[i], persistent=False)
for i in range(len(self.cache)):
self.register_buffer(f"key_cache_{i}", self.cache.layers[i].keys, persistent=False)
self.register_buffer(f"value_cache_{i}", self.cache.layers[i].values, persistent=False)
def forward(
self,
@ -550,7 +550,7 @@ class Seq2SeqLMDecoderExportableModuleWithStaticCache(torch.nn.Module):
# Initialize static cache
self.static_cache = StaticCache(
config=self.config,
model_config=self.config,
max_batch_size=batch_size,
max_cache_len=max_static_cache_length,
device="cpu",
@ -558,9 +558,9 @@ class Seq2SeqLMDecoderExportableModuleWithStaticCache(torch.nn.Module):
)
# Register cache buffers to make them exportable
for i in range(len(self.static_cache.key_cache)):
self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i], persistent=False)
self.register_buffer(f"value_cache_{i}", self.static_cache.value_cache[i], persistent=False)
for i in range(len(self.static_cache)):
self.register_buffer(f"key_cache_{i}", self.static_cache.layers[i].keys, persistent=False)
self.register_buffer(f"value_cache_{i}", self.static_cache.layers[i].values, persistent=False)
def forward(self, decoder_input_ids, encoder_hidden_states, cache_position):
# Get outputs from decoder

View File

@ -692,10 +692,10 @@ def create_causal_mask(
useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
"""
# If we have an HybridCache structure, here we want to create the mask for the full layers
if hasattr(past_key_values, "is_sliding") and False in past_key_values.is_sliding:
layer_idx = past_key_values.is_sliding.index(False)
else:
layer_idx = 0
is_sliding = []
if past_key_values is not None:
is_sliding = [getattr(layer, "is_sliding", False) for layer in past_key_values.layers]
layer_idx = is_sliding.index(True) if True in is_sliding else 0
early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments(
config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx
@ -774,10 +774,10 @@ def create_sliding_window_causal_mask(
useful to easily overlay another mask on top of the sliding causal one, for example for image tokens handling.
"""
# If we have an HybridCache structure, here we want to create the mask for the sliding layers
if hasattr(past_key_values, "is_sliding") and True in past_key_values.is_sliding:
layer_idx = past_key_values.is_sliding.index(True)
else:
layer_idx = 0
is_sliding = []
if past_key_values is not None:
is_sliding = [getattr(layer, "is_sliding", False) for layer in past_key_values.layers]
layer_idx = is_sliding.index(True) if True in is_sliding else 0
early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments(
config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx
@ -861,10 +861,10 @@ def create_chunked_causal_mask(
useful to easily overlay another mask on top of the chunked causal one, for example for image tokens handling.
"""
# If we have an HybridCache structure, here we want to create the mask for the sliding layers
if hasattr(past_key_values, "is_sliding") and True in past_key_values.is_sliding:
layer_idx = past_key_values.is_sliding.index(True)
else:
layer_idx = 0
is_sliding = []
if past_key_values is not None:
is_sliding = [getattr(layer, "is_sliding", False) for layer in past_key_values.layers]
layer_idx = is_sliding.index(True) if True in is_sliding else 0
early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments(
config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx

View File

@ -230,8 +230,8 @@ class BartAttention(nn.Module):
current_states = key_value_states if is_cross_attention else hidden_states
if is_cross_attention and past_key_value is not None and is_updated:
# reuse k,v, cross_attentions
key_states = curr_past_key_value.key_cache[self.layer_idx]
value_states = curr_past_key_value.value_cache[self.layer_idx]
key_states = curr_past_key_value.layers[self.layer_idx].keys
value_states = curr_past_key_value.layers[self.layer_idx].values
else:
key_states = self.k_proj(current_states)
value_states = self.v_proj(current_states)

View File

@ -1293,8 +1293,8 @@ class BigBirdPegasusDecoderAttention(nn.Module):
current_states = key_value_states if is_cross_attention else hidden_states
if is_cross_attention and past_key_value is not None and is_updated:
# reuse k,v, cross_attentions
key_states = curr_past_key_value.key_cache[self.layer_idx]
value_states = curr_past_key_value.value_cache[self.layer_idx]
key_states = curr_past_key_value.layers[self.layer_idx].keys
value_states = curr_past_key_value.layers[self.layer_idx].values
else:
key_states = self.k_proj(current_states)
value_states = self.v_proj(current_states)

View File

@ -207,8 +207,8 @@ class BioGptAttention(nn.Module):
current_states = key_value_states if is_cross_attention else hidden_states
if is_cross_attention and past_key_value is not None and is_updated:
# reuse k,v, cross_attentions
key_states = curr_past_key_value.key_cache[self.layer_idx]
value_states = curr_past_key_value.value_cache[self.layer_idx]
key_states = curr_past_key_value.layers[self.layer_idx].keys
value_states = curr_past_key_value.layers[self.layer_idx].values
else:
key_states = self.k_proj(current_states)
value_states = self.v_proj(current_states)

View File

@ -229,8 +229,8 @@ class BlenderbotAttention(nn.Module):
current_states = key_value_states if is_cross_attention else hidden_states
if is_cross_attention and past_key_value is not None and is_updated:
# reuse k,v, cross_attentions
key_states = curr_past_key_value.key_cache[self.layer_idx]
value_states = curr_past_key_value.value_cache[self.layer_idx]
key_states = curr_past_key_value.layers[self.layer_idx].keys
value_states = curr_past_key_value.layers[self.layer_idx].values
else:
key_states = self.k_proj(current_states)
value_states = self.v_proj(current_states)

View File

@ -213,8 +213,8 @@ class BlenderbotSmallAttention(nn.Module):
current_states = key_value_states if is_cross_attention else hidden_states
if is_cross_attention and past_key_value is not None and is_updated:
# reuse k,v, cross_attentions
key_states = curr_past_key_value.key_cache[self.layer_idx]
value_states = curr_past_key_value.value_cache[self.layer_idx]
key_states = curr_past_key_value.layers[self.layer_idx].keys
value_states = curr_past_key_value.layers[self.layer_idx].values
else:
key_states = self.k_proj(current_states)
value_states = self.v_proj(current_states)

View File

@ -356,8 +356,8 @@ class DiaCrossAttention(nn.Module):
is_updated = past_key_values.is_updated.get(self.layer_idx) if past_key_values is not None else False
if past_key_values is not None and is_updated:
# reuse k,v, cross_attentions
key_states = past_key_values.cross_attention_cache.key_cache[self.layer_idx]
value_states = past_key_values.cross_attention_cache.value_cache[self.layer_idx]
key_states = past_key_values.cross_attention_cache.layers[self.layer_idx].keys
value_states = past_key_values.cross_attention_cache.layers[self.layer_idx].values
else:
key_states = self.k_proj(cross_attention_states).view(cross_shape).transpose(1, 2)
value_states = self.v_proj(cross_attention_states).view(cross_shape).transpose(1, 2)

View File

@ -182,8 +182,8 @@ class DiaCrossAttention(nn.Module):
is_updated = past_key_values.is_updated.get(self.layer_idx) if past_key_values is not None else False
if past_key_values is not None and is_updated:
# reuse k,v, cross_attentions
key_states = past_key_values.cross_attention_cache.key_cache[self.layer_idx]
value_states = past_key_values.cross_attention_cache.value_cache[self.layer_idx]
key_states = past_key_values.cross_attention_cache.layers[self.layer_idx].keys
value_states = past_key_values.cross_attention_cache.layers[self.layer_idx].values
else:
key_states = self.k_proj(cross_attention_states).view(cross_shape).transpose(1, 2)
value_states = self.v_proj(cross_attention_states).view(cross_shape).transpose(1, 2)

View File

@ -1332,8 +1332,8 @@ class Gemma3nTextAttention(nn.Module):
else:
indices = cache_position
key_states = past_key_value.key_cache[self.kv_shared_layer_index][:, :, indices]
value_states = past_key_value.value_cache[self.kv_shared_layer_index][:, :, indices]
key_states = past_key_value.layers[self.kv_shared_layer_index].keys[:, :, indices]
value_states = past_key_value.layers[self.kv_shared_layer_index].values[:, :, indices]
else:
key_states = self.k_proj(hidden_states).view(hidden_shape)
key_states = self.k_norm(key_states)

View File

@ -1774,8 +1774,8 @@ class Gemma3nTextAttention(Gemma3Attention):
else:
indices = cache_position
key_states = past_key_value.key_cache[self.kv_shared_layer_index][:, :, indices]
value_states = past_key_value.value_cache[self.kv_shared_layer_index][:, :, indices]
key_states = past_key_value.layers[self.kv_shared_layer_index].keys[:, :, indices]
value_states = past_key_value.layers[self.kv_shared_layer_index].values[:, :, indices]
else:
key_states = self.k_proj(hidden_states).view(hidden_shape)
key_states = self.k_norm(key_states)

View File

@ -728,8 +728,9 @@ class GPTJModel(GPTJPreTrainedModel):
# Ensure layer_past is on same device as hidden_states (might not be correct)
if past_key_values is not None:
past_key_values.key_cache = past_key_values.key_cache.to(hidden_states.device)
past_key_values.value_cache = past_key_values.value_cache.to(hidden_states.device)
for layer in past_key_values.layers:
layer.keys = layer.keys.to(hidden_states.device)
layer.values = layer.values.to(hidden_states.device)
# Ensure that attention_mask is always on the same device as hidden_states
if causal_mask is not None:

View File

@ -484,8 +484,8 @@ class InformerAttention(nn.Module):
current_states = key_value_states if is_cross_attention else hidden_states
if is_cross_attention and past_key_value is not None and is_updated:
# reuse k,v, cross_attentions
key_states = curr_past_key_value.key_cache[self.layer_idx]
value_states = curr_past_key_value.value_cache[self.layer_idx]
key_states = curr_past_key_value.layers[self.layer_idx].keys
value_states = curr_past_key_value.layers[self.layer_idx].values
else:
key_states = self.k_proj(current_states)
value_states = self.v_proj(current_states)
@ -601,8 +601,8 @@ class InformerProbSparseAttention(nn.Module):
current_states = key_value_states if is_cross_attention else hidden_states
if is_cross_attention and past_key_value is not None and is_updated:
# reuse k,v, cross_attentions
key_states = curr_past_key_value.key_cache[self.layer_idx]
value_states = curr_past_key_value.value_cache[self.layer_idx]
key_states = curr_past_key_value.layers[self.layer_idx].keys
value_states = curr_past_key_value.layers[self.layer_idx].values
else:
key_states = self.k_proj(current_states)
value_states = self.v_proj(current_states)

View File

@ -290,8 +290,8 @@ class InformerProbSparseAttention(nn.Module):
current_states = key_value_states if is_cross_attention else hidden_states
if is_cross_attention and past_key_value is not None and is_updated:
# reuse k,v, cross_attentions
key_states = curr_past_key_value.key_cache[self.layer_idx]
value_states = curr_past_key_value.value_cache[self.layer_idx]
key_states = curr_past_key_value.layers[self.layer_idx].keys
value_states = curr_past_key_value.layers[self.layer_idx].values
else:
key_states = self.k_proj(current_states)
value_states = self.v_proj(current_states)

View File

@ -478,8 +478,8 @@ class LongT5Attention(nn.Module):
current_states = key_value_states if is_cross_attention else hidden_states
if is_cross_attention and past_key_value is not None and is_updated:
# reuse k,v, cross_attentions
key_states = curr_past_key_value.key_cache[self.layer_idx]
value_states = curr_past_key_value.value_cache[self.layer_idx]
key_states = curr_past_key_value.layers[self.layer_idx].keys
value_states = curr_past_key_value.layers[self.layer_idx].values
else:
key_states = self.k(current_states)
value_states = self.v(current_states)

View File

@ -294,8 +294,8 @@ class M2M100Attention(nn.Module):
current_states = key_value_states if is_cross_attention else hidden_states
if is_cross_attention and past_key_value is not None and is_updated:
# reuse k,v, cross_attentions
key_states = curr_past_key_value.key_cache[self.layer_idx]
value_states = curr_past_key_value.value_cache[self.layer_idx]
key_states = curr_past_key_value.layers[self.layer_idx].keys
value_states = curr_past_key_value.layers[self.layer_idx].values
else:
key_states = self.k_proj(current_states)
value_states = self.v_proj(current_states)

View File

@ -229,8 +229,8 @@ class MarianAttention(nn.Module):
current_states = key_value_states if is_cross_attention else hidden_states
if is_cross_attention and past_key_value is not None and is_updated:
# reuse k,v, cross_attentions
key_states = curr_past_key_value.key_cache[self.layer_idx]
value_states = curr_past_key_value.value_cache[self.layer_idx]
key_states = curr_past_key_value.layers[self.layer_idx].keys
value_states = curr_past_key_value.layers[self.layer_idx].values
else:
key_states = self.k_proj(current_states)
value_states = self.v_proj(current_states)

View File

@ -239,8 +239,8 @@ class MBartAttention(nn.Module):
current_states = key_value_states if is_cross_attention else hidden_states
if is_cross_attention and past_key_value is not None and is_updated:
# reuse k,v, cross_attentions
key_states = curr_past_key_value.key_cache[self.layer_idx]
value_states = curr_past_key_value.value_cache[self.layer_idx]
key_states = curr_past_key_value.layers[self.layer_idx].keys
value_states = curr_past_key_value.layers[self.layer_idx].values
else:
key_states = self.k_proj(current_states)
value_states = self.v_proj(current_states)

View File

@ -105,16 +105,14 @@ class MiniMaxCache(DynamicCache):
if self.linear_cache[layer_idx] != []:
self.linear_cache[layer_idx] = self.linear_cache[layer_idx].repeat_interleave(repeats, dim=0)
else:
self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0)
self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0)
self.layers[layer_idx].batch_repeat_interleave(repeats)
def batch_select_indices(self, indices: torch.Tensor):
for layer_idx in range(len(self)):
if self.linear_cache[layer_idx] != []:
self.linear_cache[layer_idx] = self.linear_cache[layer_idx][indices, ...]
else:
self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...]
self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]
self.layers[layer_idx].batch_select_indices(indices)
def crop(self, max_length: int):
raise RuntimeError("MiniMaxCache doesnot support `crop` method")

View File

@ -215,16 +215,14 @@ class MiniMaxCache(DynamicCache):
if self.linear_cache[layer_idx] != []:
self.linear_cache[layer_idx] = self.linear_cache[layer_idx].repeat_interleave(repeats, dim=0)
else:
self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0)
self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0)
self.layers[layer_idx].batch_repeat_interleave(repeats)
def batch_select_indices(self, indices: torch.Tensor):
for layer_idx in range(len(self)):
if self.linear_cache[layer_idx] != []:
self.linear_cache[layer_idx] = self.linear_cache[layer_idx][indices, ...]
else:
self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...]
self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]
self.layers[layer_idx].batch_select_indices(indices)
def crop(self, max_length: int):
raise RuntimeError("MiniMaxCache doesnot support `crop` method")

View File

@ -496,8 +496,8 @@ class MllamaTextCrossAttention(nn.Module):
)
elif cache_position[0] != 0:
key_states, value_states = (
past_key_value.key_cache[self.layer_idx],
past_key_value.value_cache[self.layer_idx],
past_key_value.layers[self.layer_idx].keys,
past_key_value.layers[self.layer_idx].values,
)
else:
raise ValueError(

View File

@ -235,8 +235,8 @@ class MoonshineAttention(nn.Module):
# use key_value_states if cross attention
current_states = key_value_states if key_value_states is not None else hidden_states
if is_cross_attention and past_key_value and is_updated:
key_states = past_key_value.key_cache[self.layer_idx]
value_states = past_key_value.value_cache[self.layer_idx]
key_states = past_key_value.layers[self.layer_idx].keys
value_states = past_key_value.layers[self.layer_idx].values
else:
key_states = (
self.k_proj(current_states)

View File

@ -331,8 +331,8 @@ class MoonshineAttention(GlmAttention):
# use key_value_states if cross attention
current_states = key_value_states if key_value_states is not None else hidden_states
if is_cross_attention and past_key_value and is_updated:
key_states = past_key_value.key_cache[self.layer_idx]
value_states = past_key_value.value_cache[self.layer_idx]
key_states = past_key_value.layers[self.layer_idx].keys
value_states = past_key_value.layers[self.layer_idx].values
else:
key_states = (
self.k_proj(current_states)

View File

@ -376,8 +376,8 @@ class MT5Attention(nn.Module):
current_states = key_value_states if is_cross_attention else hidden_states
if is_cross_attention and past_key_value is not None and is_updated:
# reuse k,v, cross_attentions
key_states = curr_past_key_value.key_cache[self.layer_idx]
value_states = curr_past_key_value.value_cache[self.layer_idx]
key_states = curr_past_key_value.layers[self.layer_idx].keys
value_states = curr_past_key_value.layers[self.layer_idx].values
else:
key_states = self.k(current_states)
value_states = self.v(current_states)

View File

@ -228,8 +228,8 @@ class PegasusAttention(nn.Module):
current_states = key_value_states if is_cross_attention else hidden_states
if is_cross_attention and past_key_value is not None and is_updated:
# reuse k,v, cross_attentions
key_states = curr_past_key_value.key_cache[self.layer_idx]
value_states = curr_past_key_value.value_cache[self.layer_idx]
key_states = curr_past_key_value.layers[self.layer_idx].keys
value_states = curr_past_key_value.layers[self.layer_idx].values
else:
key_states = self.k_proj(current_states)
value_states = self.v_proj(current_states)

View File

@ -249,8 +249,8 @@ class PegasusXAttention(nn.Module):
current_states = key_value_states if is_cross_attention else hidden_states
if is_cross_attention and past_key_value is not None and is_updated:
# reuse k,v, cross_attentions
key_states = curr_past_key_value.key_cache[self.layer_idx]
value_states = curr_past_key_value.value_cache[self.layer_idx]
key_states = curr_past_key_value.layers[self.layer_idx].keys
value_states = curr_past_key_value.layers[self.layer_idx].values
else:
key_states = self.k_proj(current_states)
value_states = self.v_proj(current_states)

View File

@ -770,8 +770,8 @@ class Pix2StructTextAttention(nn.Module):
current_states = key_value_states if is_cross_attention else hidden_states
if is_cross_attention and past_key_value and is_updated:
# reuse k,v, cross_attentions
key_states = curr_past_key_value.key_cache[self.layer_idx]
value_states = curr_past_key_value.value_cache[self.layer_idx]
key_states = curr_past_key_value.layers[self.layer_idx].keys
value_states = curr_past_key_value.layers[self.layer_idx].values
else:
key_states = self.key(current_states)
value_states = self.value(current_states)

View File

@ -425,8 +425,8 @@ class PLBartAttention(nn.Module):
current_states = key_value_states if is_cross_attention else hidden_states
if is_cross_attention and past_key_value is not None and is_updated:
# reuse k,v, cross_attentions
key_states = curr_past_key_value.key_cache[self.layer_idx]
value_states = curr_past_key_value.value_cache[self.layer_idx]
key_states = curr_past_key_value.layers[self.layer_idx].keys
value_states = curr_past_key_value.layers[self.layer_idx].values
else:
key_states = self.k_proj(current_states)
value_states = self.v_proj(current_states)

View File

@ -320,8 +320,8 @@ class Pop2PianoAttention(nn.Module):
current_states = key_value_states if is_cross_attention else hidden_states
if is_cross_attention and past_key_value is not None and is_updated:
# reuse k,v, cross_attentions
key_states = curr_past_key_value.key_cache[self.layer_idx]
value_states = curr_past_key_value.value_cache[self.layer_idx]
key_states = curr_past_key_value.layers[self.layer_idx].keys
value_states = curr_past_key_value.layers[self.layer_idx].values
else:
key_states = self.k(current_states)
value_states = self.v(current_states)

View File

@ -513,8 +513,8 @@ class SwitchTransformersAttention(nn.Module):
current_states = key_value_states if is_cross_attention else hidden_states
if is_cross_attention and past_key_value is not None and is_updated:
# reuse k,v, cross_attentions
key_states = curr_past_key_value.key_cache[self.layer_idx]
value_states = curr_past_key_value.value_cache[self.layer_idx]
key_states = curr_past_key_value.layers[self.layer_idx].keys
value_states = curr_past_key_value.layers[self.layer_idx].values
else:
key_states = self.k(current_states)
value_states = self.v(current_states)

View File

@ -501,8 +501,8 @@ class T5Attention(nn.Module):
current_states = key_value_states if is_cross_attention else hidden_states
if is_cross_attention and past_key_value is not None and is_updated:
# reuse k,v, cross_attentions
key_states = curr_past_key_value.key_cache[self.layer_idx]
value_states = curr_past_key_value.value_cache[self.layer_idx]
key_states = curr_past_key_value.layers[self.layer_idx].keys
value_states = curr_past_key_value.layers[self.layer_idx].values
else:
key_states = self.k(current_states)
value_states = self.v(current_states)

View File

@ -352,8 +352,8 @@ class T5GemmaCrossAttention(nn.Module):
past_key_value.is_updated[self.layer_idx] = True
# cross-attention: reuse cached states
else:
key_states = curr_past_key_value.key_cache[self.layer_idx]
value_states = curr_past_key_value.value_cache[self.layer_idx]
key_states = curr_past_key_value.layers[self.layer_idx].keys
value_states = curr_past_key_value.layers[self.layer_idx].values
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":

View File

@ -308,8 +308,8 @@ class T5GemmaCrossAttention(Gemma2Attention):
past_key_value.is_updated[self.layer_idx] = True
# cross-attention: reuse cached states
else:
key_states = curr_past_key_value.key_cache[self.layer_idx]
value_states = curr_past_key_value.value_cache[self.layer_idx]
key_states = curr_past_key_value.layers[self.layer_idx].keys
value_states = curr_past_key_value.layers[self.layer_idx].values
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":

View File

@ -394,8 +394,8 @@ class TimeSeriesTransformerAttention(nn.Module):
current_states = key_value_states if is_cross_attention else hidden_states
if is_cross_attention and past_key_value is not None and is_updated:
# reuse k,v, cross_attentions
key_states = curr_past_key_value.key_cache[self.layer_idx]
value_states = curr_past_key_value.value_cache[self.layer_idx]
key_states = curr_past_key_value.layers[self.layer_idx].keys
value_states = curr_past_key_value.layers[self.layer_idx].values
else:
key_states = self.k_proj(current_states)
value_states = self.v_proj(current_states)

View File

@ -599,8 +599,8 @@ class UdopAttention(nn.Module):
current_states = key_value_states if is_cross_attention else hidden_states
if is_cross_attention and past_key_value is not None and is_updated:
# reuse k,v, cross_attentions
key_states = curr_past_key_value.key_cache[self.layer_idx]
value_states = curr_past_key_value.value_cache[self.layer_idx]
key_states = curr_past_key_value.layers[self.layer_idx].keys
value_states = curr_past_key_value.layers[self.layer_idx].values
else:
key_states = self.k(current_states)
value_states = self.v(current_states)

View File

@ -285,8 +285,8 @@ class UMT5Attention(nn.Module):
current_states = encoder_hidden_states if is_cross_attention else hidden_states
if is_cross_attention and past_key_value is not None and is_updated:
# reuse k,v, cross_attentions
key_states = curr_past_key_value.key_cache[self.layer_idx]
value_states = curr_past_key_value.value_cache[self.layer_idx]
key_states = curr_past_key_value.layers[self.layer_idx].keys
value_states = curr_past_key_value.layers[self.layer_idx].values
else:
key_states = self.k(current_states)
value_states = self.v(current_states)

View File

@ -1140,8 +1140,8 @@ class WhisperGenerationMixin(GenerationMixin):
for layer_idx in range(self.config.decoder_layers):
layer_past_key_values = []
for cache_cls in [values.self_attention_cache, values.cross_attention_cache]:
for v in [cache_cls.key_cache, cache_cls.value_cache]:
layer_past_key_values.append(v[layer_idx][batch_idx][None].cpu())
for v in [cache_cls.layers[layer_idx].keys, cache_cls.layers[layer_idx].values]:
layer_past_key_values.append(v[batch_idx][None].cpu())
all_past_key_values.append(tuple(layer_past_key_values))
return tuple(all_past_key_values)
else:

View File

@ -329,8 +329,8 @@ class WhisperAttention(nn.Module):
current_states = key_value_states if key_value_states is not None else hidden_states
if is_cross_attention and past_key_value and is_updated:
# reuse k,v, cross_attentions
key_states = past_key_value.key_cache[self.layer_idx]
value_states = past_key_value.value_cache[self.layer_idx]
key_states = past_key_value.layers[self.layer_idx].keys
value_states = past_key_value.layers[self.layer_idx].values
else:
key_states = self.k_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim)
value_states = self.v_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim)

View File

@ -1595,9 +1595,7 @@ class GenerationTesterMixin:
# 3. Check cache shapes
# 3.1. Encoder-Decoder checks
if config.is_encoder_decoder:
num_cache_decoder_layers = (
len(past_kv) if is_legacy_cache else len(past_kv.self_attention_cache.key_cache)
)
num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv.self_attention_cache)
self.assertEqual(num_cache_decoder_layers, num_decoder_layers)
for i in range(num_decoder_layers):
@ -1605,30 +1603,30 @@ class GenerationTesterMixin:
self.assertEqual(len(past_kv[0]), 4) # legacy check: confirm number of elements in tuple
# Self attention
self_attention_layer_key_cache = (
past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.key_cache[i]
self_attention_layer_keys = (
past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.layers[i].keys
)
self_attention_layer_value_cache = (
past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.value_cache[i]
self_attention_layer_values = (
past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.layers[i].values
)
self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0])
self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1])
self.assertEqual(self_attention_layer_keys.shape, all_cache_shapes[i][0])
self.assertEqual(self_attention_layer_values.shape, all_cache_shapes[i][1])
# Cross attention (ignore 3rd dim, see default shape preparation)
cross_attention_layer_key_cache = (
past_kv[i][2] if is_legacy_cache else past_kv.cross_attention_cache.key_cache[i]
cross_attention_layer_keys = (
past_kv[i][2] if is_legacy_cache else past_kv.cross_attention_cache.layers[i].keys
)
cross_attention_layer_value_cache = (
past_kv[i][3] if is_legacy_cache else past_kv.cross_attention_cache.value_cache[i]
cross_attention_layer_values = (
past_kv[i][3] if is_legacy_cache else past_kv.cross_attention_cache.layers[i].values
)
cross_attention_layer_key_cache = cross_attention_layer_key_cache[:, :, 0, :]
cross_attention_layer_value_cache = cross_attention_layer_value_cache[:, :, 0, :]
self.assertEqual(cross_attention_layer_key_cache.shape, all_cache_shapes[i][2])
self.assertEqual(cross_attention_layer_value_cache.shape, all_cache_shapes[i][3])
cross_attention_layer_keys = cross_attention_layer_keys[:, :, 0, :]
cross_attention_layer_values = cross_attention_layer_values[:, :, 0, :]
self.assertEqual(cross_attention_layer_keys.shape, all_cache_shapes[i][2])
self.assertEqual(cross_attention_layer_values.shape, all_cache_shapes[i][3])
# 3.2. Decoder-only checks
else:
num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv.key_cache)
num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv)
self.assertEqual(num_cache_decoder_layers, num_decoder_layers)
for i in range(num_decoder_layers):
@ -1636,10 +1634,10 @@ class GenerationTesterMixin:
self.assertEqual(len(past_kv[0]), 2) # legacy check: confirm number of elements in tuple
# Self attention
self_attention_layer_key_cache = past_kv[i][0] if is_legacy_cache else past_kv.key_cache[i]
self_attention_layer_value_cache = past_kv[i][1] if is_legacy_cache else past_kv.value_cache[i]
self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0])
self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1])
self_attention_layer_keys = past_kv[i][0] if is_legacy_cache else past_kv.layers[i].keys
self_attention_layer_values = past_kv[i][1] if is_legacy_cache else past_kv.layers[i].values
self.assertEqual(self_attention_layer_keys.shape, all_cache_shapes[i][0])
self.assertEqual(self_attention_layer_values.shape, all_cache_shapes[i][1])
@pytest.mark.generate
@parameterized.expand([("greedy", 1), ("beam search", 2)])
@ -1798,8 +1796,8 @@ class GenerationTesterMixin:
max_length = max_new_tokens + inputs_embeds.shape[1] - 1
cache_shape = [batch_size, num_key_value_heads, max_length, head_dim]
self.assertIsInstance(outputs.past_key_values, StaticCache)
self.assertEqual(len(outputs.past_key_values.key_cache), num_hidden_layers)
self.assertListEqual(list(outputs.past_key_values.key_cache[0].shape), cache_shape)
self.assertEqual(len(outputs.past_key_values), num_hidden_layers)
self.assertListEqual(list(outputs.past_key_values.layers[0].keys.shape), cache_shape)
@pytest.mark.generate
def test_generate_continue_from_past_key_values(self):
@ -2029,8 +2027,8 @@ class GenerationTesterMixin:
num_hidden_layers = text_config.num_hidden_layers
cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim)
self.assertTrue(isinstance(static_cache_generation.past_key_values, StaticCache))
self.assertTrue(len(static_cache_generation.past_key_values.key_cache) == num_hidden_layers)
self.assertTrue(static_cache_generation.past_key_values.key_cache[0].shape == cache_shape)
self.assertTrue(len(static_cache_generation.past_key_values) == num_hidden_layers)
self.assertTrue(static_cache_generation.past_key_values.layers[0].keys.shape == cache_shape)
# Check 2: The outputs must be similar to the case with dynamic cache
dynamic_cache_generation = model.generate(**generation_kwargs, **inputs_dict)
@ -2612,12 +2610,12 @@ class GenerationTesterMixin:
if isinstance(decoder_past_key_values, Cache):
self.assertListEqual(
[key_tensor.shape for key_tensor in decoder_past_key_values.key_cache],
[expected_shape] * len(decoder_past_key_values.key_cache),
[layer.keys.shape for layer in decoder_past_key_values.layers],
[expected_shape] * len(decoder_past_key_values.layers),
)
self.assertListEqual(
[value_tensor.shape for value_tensor in decoder_past_key_values.value_cache],
[expected_shape] * len(decoder_past_key_values.value_cache),
[layer.values.shape for layer in decoder_past_key_values.layers],
[expected_shape] * len(decoder_past_key_values.layers),
)
# Legacy cache format checks. This branch should be removed when all models use `Cache` by default
@ -3976,13 +3974,13 @@ class GenerationIntegrationTests(unittest.TestCase):
self.assertTrue(isinstance(results.past_key_values, StaticCache))
# check device of each layer
key_cache_0 = results.past_key_values.key_cache[0]
value_cache_0 = results.past_key_values.value_cache[0]
self.assertTrue(key_cache_0.device == value_cache_0.device == torch.device(0))
keys_0 = results.past_key_values.layers[0].keys
values_0 = results.past_key_values.layers[0].values
self.assertTrue(keys_0.device == values_0.device == torch.device(0))
key_cache_1 = results.past_key_values.key_cache[1]
value_cache_1 = results.past_key_values.value_cache[1]
self.assertTrue(key_cache_1.device == value_cache_1.device == torch.device(1))
keys_1 = results.past_key_values.layers[1].keys
values_1 = results.past_key_values.layers[1].values
self.assertTrue(keys_1.device == values_1.device == torch.device(1))
@pytest.mark.generate
@require_torch_multi_accelerator
@ -4054,13 +4052,13 @@ class GenerationIntegrationTests(unittest.TestCase):
results = model.generate(input_ids, past_key_values=past_key_values, **generation_kwargs)
# check device of each layer
key_cache_0 = results.past_key_values.key_cache[0]
value_cache_0 = results.past_key_values.value_cache[0]
self.assertTrue(key_cache_0.device == value_cache_0.device == torch.device(0))
keys_0 = results.past_key_values.layers[0].keys
values_0 = results.past_key_values.layers[0].values
self.assertTrue(keys_0.device == values_0.device == torch.device(0))
key_cache_1 = results.past_key_values.key_cache[1]
value_cache_1 = results.past_key_values.value_cache[1]
self.assertTrue(key_cache_1.device == value_cache_1.device == torch.device(1))
keys_1 = results.past_key_values.layers[1].keys
values_1 = results.past_key_values.layers[1].values
self.assertTrue(keys_1.device == values_1.device == torch.device(1))
@slow
def test_padding_input_contrastive_search_gpt2(self):

View File

@ -440,13 +440,11 @@ class DeepseekV3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
# difference: last dim
k_embed_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
v_embed_dim = config.v_head_dim
self_attention_key_cache_shape = (batch_size, config.num_key_value_heads, seq_length, k_embed_dim)
self_attention_value_cache_shape = (batch_size, config.num_key_value_heads, seq_length, v_embed_dim)
self_attention_keys_shape = (batch_size, config.num_key_value_heads, seq_length, k_embed_dim)
self_attention_values_shape = (batch_size, config.num_key_value_heads, seq_length, v_embed_dim)
# build the full cache shapes
num_hidden_layers = config.num_hidden_layers
all_cache_shapes = [
[self_attention_key_cache_shape, self_attention_value_cache_shape] for _ in range(num_hidden_layers)
]
all_cache_shapes = [[self_attention_keys_shape, self_attention_values_shape] for _ in range(num_hidden_layers)]
super().test_past_key_values_format(custom_all_cache_shapes=all_cache_shapes)
@require_torch_large_accelerator

View File

@ -399,12 +399,12 @@ class DiaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
if isinstance(decoder_past_key_values, Cache):
self.assertListEqual(
[key_tensor.shape for key_tensor in decoder_past_key_values.key_cache],
[expected_shape] * len(decoder_past_key_values.key_cache),
[layer.keys.shape for layer in decoder_past_key_values.layers],
[expected_shape] * len(decoder_past_key_values.layers),
)
self.assertListEqual(
[value_tensor.shape for value_tensor in decoder_past_key_values.value_cache],
[expected_shape] * len(decoder_past_key_values.value_cache),
[layer.values.shape for layer in decoder_past_key_values.layers],
[expected_shape] * len(decoder_past_key_values.layers),
)
def _check_scores(self, batch_size, scores, generated_length, config):

View File

@ -235,8 +235,8 @@ class GPTNeoXModelTester:
"""Deep copy a DynamicCache to reuse the same one multiple times."""
new_cache = cache
for i in range(len(cache)):
new_cache.key_cache[i] = cache.key_cache[i].clone()
new_cache.value_cache[i] = cache.value_cache[i].clone()
new_cache.layers[i].keys = cache.layers[i].keys.clone()
new_cache.layers[i].values = cache.layers[i].values.clone()
# Cached forward once with the attention mask provided and the other time without it (which should assume full attention)
# We need to run both on a copy of the cache, otherwise it is modified in-place

View File

@ -272,7 +272,7 @@ class T5GemmaModelTester:
self.parent.assertEqual(decoder_output.size(), (self.batch_size, self.seq_length, self.hidden_size))
self.parent.assertIsNotNone(decoder_past)
self.parent.assertEqual(len(decoder_past.self_attention_cache), config.decoder.num_hidden_layers)
self.parent.assertEqual(len(decoder_past.cross_attention_cache.key_cache), config.decoder.num_hidden_layers)
self.parent.assertEqual(len(decoder_past.cross_attention_cache), config.decoder.num_hidden_layers)
def check_prepare_lm_labels_via_shift_left(
self,
@ -1069,9 +1069,7 @@ class T5GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
# 3. Check cache shapes
# 3.1. Encoder-Decoder checks
if config.is_encoder_decoder:
num_cache_decoder_layers = (
len(past_kv) if is_legacy_cache else len(past_kv.self_attention_cache.key_cache)
)
num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv.self_attention_cache)
self.assertEqual(num_cache_decoder_layers, num_decoder_layers)
for i in range(num_decoder_layers):
@ -1079,30 +1077,30 @@ class T5GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
self.assertEqual(len(past_kv[0]), 4) # legacy check: confirm number of elements in tuple
# Self attention
self_attention_layer_key_cache = (
past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.key_cache[i]
self_attention_layer_keys = (
past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.layers[i].keys
)
self_attention_layer_value_cache = (
past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.value_cache[i]
self_attention_layer_values = (
past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.layers[i].values
)
self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0])
self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1])
self.assertEqual(self_attention_layer_keys.shape, all_cache_shapes[i][0])
self.assertEqual(self_attention_layer_values.shape, all_cache_shapes[i][1])
# Cross attention (ignore 3rd dim, see default shape preparation)
cross_attention_layer_key_cache = (
past_kv[i][2] if is_legacy_cache else past_kv.cross_attention_cache.key_cache[i]
cross_attention_layer_keys = (
past_kv[i][2] if is_legacy_cache else past_kv.cross_attention_cache.layers[i].keys
)
cross_attention_layer_value_cache = (
past_kv[i][3] if is_legacy_cache else past_kv.cross_attention_cache.value_cache[i]
cross_attention_layer_values = (
past_kv[i][3] if is_legacy_cache else past_kv.cross_attention_cache.layers[i].values
)
cross_attention_layer_key_cache = cross_attention_layer_key_cache[:, :, 0, :]
cross_attention_layer_value_cache = cross_attention_layer_value_cache[:, :, 0, :]
self.assertEqual(cross_attention_layer_key_cache.shape, all_cache_shapes[i][2])
self.assertEqual(cross_attention_layer_value_cache.shape, all_cache_shapes[i][3])
cross_attention_layer_keys = cross_attention_layer_keys[:, :, 0, :]
cross_attention_layer_values = cross_attention_layer_values[:, :, 0, :]
self.assertEqual(cross_attention_layer_keys.shape, all_cache_shapes[i][2])
self.assertEqual(cross_attention_layer_values.shape, all_cache_shapes[i][3])
# 3.2. Decoder-only checks
else:
num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv.key_cache)
num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv)
self.assertEqual(num_cache_decoder_layers, num_decoder_layers)
for i in range(num_decoder_layers):
@ -1110,10 +1108,10 @@ class T5GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
self.assertEqual(len(past_kv[0]), 2) # legacy check: confirm number of elements in tuple
# Self attention
self_attention_layer_key_cache = past_kv[i][0] if is_legacy_cache else past_kv.key_cache[i]
self_attention_layer_value_cache = past_kv[i][1] if is_legacy_cache else past_kv.value_cache[i]
self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0])
self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1])
self_attention_layer_keys = past_kv[i][0] if is_legacy_cache else past_kv.layers[i].keys
self_attention_layer_values = past_kv[i][1] if is_legacy_cache else past_kv.layers[i].values
self.assertEqual(self_attention_layer_keys.shape, all_cache_shapes[i][0])
self.assertEqual(self_attention_layer_values.shape, all_cache_shapes[i][1])
@unittest.skip("Mismatch issue doesn't exist in T5Gemma.")
def test_load_with_mismatched_shapes(self):

View File

@ -171,7 +171,9 @@ class CacheTest(unittest.TestCase):
return random_keys, random_values
mha_config = LlamaConfig(num_attention_heads=32)
mha_static_cache = StaticCache(config=mha_config, max_batch_size=1, max_cache_len=10, device=torch_device)
mha_static_cache = StaticCache(
model_config=mha_config, max_batch_size=1, max_cache_len=10, device=torch_device
)
cached_keys, cached_values = mha_static_cache.update(
*_random_kvs(mha_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)}
)
@ -179,7 +181,9 @@ class CacheTest(unittest.TestCase):
self.assertTrue(cached_values.shape == (1, 32, 10, 128))
gqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=4)
gqa_static_cache = StaticCache(config=gqa_config, max_batch_size=1, max_cache_len=10, device=torch_device)
gqa_static_cache = StaticCache(
model_config=gqa_config, max_batch_size=1, max_cache_len=10, device=torch_device
)
cached_keys, cached_values = gqa_static_cache.update(
*_random_kvs(gqa_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)}
)
@ -187,7 +191,9 @@ class CacheTest(unittest.TestCase):
self.assertTrue(cached_values.shape == (1, 4, 10, 128))
mqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=1)
mqa_static_cache = StaticCache(config=mqa_config, max_batch_size=1, max_cache_len=10, device=torch_device)
mqa_static_cache = StaticCache(
model_config=mqa_config, max_batch_size=1, max_cache_len=10, device=torch_device
)
cached_keys, cached_values = mqa_static_cache.update(
*_random_kvs(mqa_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)}
)
@ -323,7 +329,7 @@ class CacheIntegrationTest(unittest.TestCase):
)
self.assertIsInstance(gen_out.past_key_values, QuantizedCache)
processor = gen_out.past_key_values.processor
processor = gen_out.past_key_values.cache_processor
if backend == "quanto":
self.assertIsInstance(processor, QuantoQuantizedCacheProcessor)
elif backend == "hqq":
@ -332,12 +338,10 @@ class CacheIntegrationTest(unittest.TestCase):
decoded = self.tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True)
self.assertListEqual(decoded, expected_generation)
self.assertTrue(len(processor._quantized_key_cache) > 0)
self.assertTrue(len(processor._quantized_keys) > 0)
# Check that something is actually quantized
has_been_quantized = any(
(q[0] if isinstance(q, tuple) else q).numel() > 0 for q in processor._quantized_key_cache
)
has_been_quantized = any((q[0] if isinstance(q, tuple) else q).numel() > 0 for q in processor._quantized_keys)
self.assertTrue(has_been_quantized)
@parameterized.expand(TEST_CACHE_IMPLEMENTATIONS)
@ -654,7 +658,7 @@ class CacheExportIntegrationTest(unittest.TestCase):
past_key_values=DynamicCache(),
use_cache=True,
)
self.assertTrue(len(res.past_key_values.key_cache) == model.config.num_hidden_layers)
self.assertTrue(len(res.past_key_values) == model.config.num_hidden_layers)
self.assertEqual(2 * model.config.num_hidden_layers + 1, len(ep.graph_signature.output_specs))
self.assertEqual(
3,
@ -675,11 +679,9 @@ class CacheExportIntegrationTest(unittest.TestCase):
use_cache=True,
)
self.assertTrue(torch.allclose(res.logits, res_eager.logits))
for k1, k2 in zip(res.past_key_values.key_cache, res_eager.past_key_values.key_cache):
self.assertTrue(torch.allclose(k1, k2))
for v1, v2 in zip(res.past_key_values.value_cache, res_eager.past_key_values.value_cache):
self.assertTrue(torch.allclose(v1, v2))
for l1, l2 in zip(res.past_key_values.layers, res_eager.past_key_values.layers):
self.assertTrue(torch.allclose(l1.keys, l2.keys))
self.assertTrue(torch.allclose(l1.values, l2.values))
def test_dynamic_cache_exportability_multiple_run(self):
# When exporting with DynamicCache, you should export two graphs:
@ -703,7 +705,7 @@ class CacheExportIntegrationTest(unittest.TestCase):
past_key_values=DynamicCache(),
use_cache=True,
)
self.assertTrue(len(res.past_key_values.key_cache) == model.config.num_hidden_layers)
self.assertTrue(len(res.past_key_values) == model.config.num_hidden_layers)
self.assertEqual(2 * model.config.num_hidden_layers + 1, len(ep.graph_signature.output_specs))
self.assertEqual(
3,
@ -728,9 +730,9 @@ class CacheExportIntegrationTest(unittest.TestCase):
shapes = torch.export.ShapesCollection()
dyn = torch.export.Dim("seq", max=512)
for ix in range(len(past_key_values.key_cache)):
shapes[past_key_values.key_cache[ix]] = (None, None, dyn, None)
shapes[past_key_values.value_cache[ix]] = (None, None, dyn, None)
for ix in range(len(past_key_values)):
shapes[past_key_values.layers[ix].keys] = (None, None, dyn, None)
shapes[past_key_values.layers[ix].values] = (None, None, dyn, None)
ep_second = torch.export.export(
model,
@ -771,11 +773,9 @@ class CacheExportIntegrationTest(unittest.TestCase):
use_cache=True,
)
for k1, k2 in zip(res_export_2.past_key_values.key_cache, res_eager_2.past_key_values.key_cache):
self.assertTrue(torch.allclose(k1, k2))
for v1, v2 in zip(res_export_2.past_key_values.value_cache, res_eager_2.past_key_values.value_cache):
self.assertTrue(torch.allclose(v1, v2))
for l1, l2 in zip(res_export_2.past_key_values.layers, res_eager_2.past_key_values.layers):
self.assertTrue(torch.allclose(l1.keys, l2.keys))
self.assertTrue(torch.allclose(l1.values, l2.values))
def test_static_cache_exportability(self):
"""
@ -922,7 +922,7 @@ class SyntheticCacheTest(unittest.TestCase):
def test_static_cache_out_of_bounds(self):
"""Test StaticCache raises IndexError for out-of-bounds positions."""
static_cache = StaticCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
static_cache = StaticCache(model_config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
pos_out_of_bounds = torch.tensor([self.max_cache_len]) # Position >= max_cache_len
with self.assertRaises(IndexError):
@ -944,7 +944,7 @@ class SyntheticCacheTest(unittest.TestCase):
update pos 3: [1.0, 2.0, 3.0, 4.0]
"""
# Scenario 1: Fill up to near capacity
static_cache = StaticCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
static_cache = StaticCache(model_config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None]
static_cache.update(key_states=prefill, value_states=prefill, layer_idx=0, cache_kwargs=None)
static_cache.update(
@ -954,7 +954,7 @@ class SyntheticCacheTest(unittest.TestCase):
cache_kwargs={"cache_position": torch.tensor([2])},
)
self.assertEqual(
static_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "StaticCache Scenario 1 failed"
static_cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "StaticCache Scenario 1 failed"
)
# Scenario 2: Fill to capacity
@ -965,7 +965,7 @@ class SyntheticCacheTest(unittest.TestCase):
cache_kwargs={"cache_position": torch.tensor([3])},
)
self.assertEqual(
static_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "StaticCache Scenario 2 failed"
static_cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "StaticCache Scenario 2 failed"
)
def test_sliding_window_cache(self):
@ -984,7 +984,9 @@ class SyntheticCacheTest(unittest.TestCase):
result: [3.0, 4.0, 5.0, 6.0] (keeps last window_size tokens)
"""
# Scenario 1: Update within window, no slide yet
sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
sliding_cache = SlidingWindowCache(
model_config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len
)
prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None]
sliding_cache.update(
key_states=prefill,
@ -999,13 +1001,15 @@ class SyntheticCacheTest(unittest.TestCase):
cache_kwargs={"cache_position": torch.tensor([2]), "sliding_window": self.window_size},
)
self.assertEqual(
sliding_cache.key_cache[0][0, 0, :, 0].tolist(),
sliding_cache.layers[0].keys[0, 0, :, 0].tolist(),
[1.0, 2.0, 3.0, 0.0],
"SlidingWindowCache Scenario 1 failed",
)
# Scenario 2: Update causing slide
sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
sliding_cache = SlidingWindowCache(
model_config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len
)
prefill = torch.tensor([1.0, 2.0, 3.0, 4.0])[None, None, :, None]
sliding_cache.update(
key_states=prefill,
@ -1020,13 +1024,15 @@ class SyntheticCacheTest(unittest.TestCase):
cache_kwargs={"cache_position": torch.tensor([4]), "sliding_window": self.window_size},
)
self.assertEqual(
sliding_cache.key_cache[0][0, 0, :, 0].tolist(),
sliding_cache.layers[0].keys[0, 0, :, 0].tolist(),
[2.0, 3.0, 4.0, 5.0],
"SlidingWindowCache Scenario 2 failed",
)
# Scenario 3: Long prompt handling
sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
sliding_cache = SlidingWindowCache(
model_config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len
)
long_prefill = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])[None, None, :, None]
sliding_cache.update(
key_states=long_prefill,
@ -1035,7 +1041,7 @@ class SyntheticCacheTest(unittest.TestCase):
cache_kwargs={"cache_position": torch.arange(6), "sliding_window": self.window_size},
)
self.assertEqual(
sliding_cache.key_cache[0][0, 0, :, 0].tolist(),
sliding_cache.layers[0].keys[0, 0, :, 0].tolist(),
[3.0, 4.0, 5.0, 6.0],
"SlidingWindowCache Scenario 3 failed",
)
@ -1054,7 +1060,7 @@ class SyntheticCacheTest(unittest.TestCase):
config.sliding_window_pattern = 1 # Layer 0 is static (1 % 1 == 0)
# Scenario 1
hybrid_cache_static_mode = HybridCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len)
hybrid_cache_static_mode = HybridCache(model_config=config, max_batch_size=1, max_cache_len=self.max_cache_len)
prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None]
hybrid_cache_static_mode.update(
key_states=prefill,
@ -1069,7 +1075,7 @@ class SyntheticCacheTest(unittest.TestCase):
cache_kwargs={"cache_position": torch.tensor([2])},
)
self.assertEqual(
hybrid_cache_static_mode.key_cache[0][0, 0, :, 0].tolist(),
hybrid_cache_static_mode.layers[0].keys[0, 0, :, 0].tolist(),
[1.0, 2.0, 3.0, 0.0],
"HybridCache Static Scenario 1 failed",
)
@ -1082,7 +1088,7 @@ class SyntheticCacheTest(unittest.TestCase):
cache_kwargs={"cache_position": torch.tensor([3])},
)
self.assertEqual(
hybrid_cache_static_mode.key_cache[0][0, 0, :, 0].tolist(),
hybrid_cache_static_mode.layers[0].keys[0, 0, :, 0].tolist(),
[1.0, 2.0, 3.0, 4.0],
"HybridCache Static Scenario 2 failed",
)
@ -1106,7 +1112,7 @@ class SyntheticCacheTest(unittest.TestCase):
result: [3.0, 4.0, 5.0, 6.0] (keeps last window_size tokens)
"""
# Scenario 1: Update within window, no slide yet
hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
hybrid_cache = HybridCache(model_config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None]
hybrid_cache.update(
key_states=prefill,
@ -1121,13 +1127,13 @@ class SyntheticCacheTest(unittest.TestCase):
cache_kwargs={"cache_position": torch.tensor([2]), "sliding_window": self.window_size},
)
self.assertEqual(
hybrid_cache.key_cache[0][0, 0, :, 0].tolist(),
hybrid_cache.layers[0].keys[0, 0, :, 0].tolist(),
[1.0, 2.0, 3.0, 0.0],
"HybridCache Sliding Scenario 1 failed",
)
# Scenario 2: Update causing first slide
hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
hybrid_cache = HybridCache(model_config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
prefill = torch.tensor([1.0, 2.0, 3.0, 4.0])[None, None, :, None]
hybrid_cache.update(
key_states=prefill,
@ -1142,7 +1148,7 @@ class SyntheticCacheTest(unittest.TestCase):
cache_kwargs={"cache_position": torch.tensor([4]), "sliding_window": self.window_size},
)
self.assertEqual(
hybrid_cache.key_cache[0][0, 0, :, 0].tolist(),
hybrid_cache.layers[0].keys[0, 0, :, 0].tolist(),
[2.0, 3.0, 4.0, 5.0],
"HybridCache Sliding Scenario 2 failed",
)
@ -1155,13 +1161,13 @@ class SyntheticCacheTest(unittest.TestCase):
cache_kwargs={"cache_position": torch.tensor([5]), "sliding_window": self.window_size},
)
self.assertEqual(
hybrid_cache.key_cache[0][0, 0, :, 0].tolist(),
hybrid_cache.layers[0].keys[0, 0, :, 0].tolist(),
[3.0, 4.0, 5.0, 6.0],
"HybridCache Sliding Scenario 3 failed",
)
# Scenario 4: Long prompt handling
hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
hybrid_cache = HybridCache(model_config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
long_prefill = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])[None, None, :, None]
hybrid_cache.update(
key_states=long_prefill,
@ -1170,7 +1176,7 @@ class SyntheticCacheTest(unittest.TestCase):
cache_kwargs={"cache_position": torch.arange(6), "sliding_window": self.window_size},
)
self.assertEqual(
hybrid_cache.key_cache[0][0, 0, :, 0].tolist(),
hybrid_cache.layers[0].keys[0, 0, :, 0].tolist(),
[3.0, 4.0, 5.0, 6.0],
"HybridCache Sliding Scenario 4 failed",
)
@ -1190,10 +1196,10 @@ class SyntheticCacheTest(unittest.TestCase):
cache = DynamicCache()
cache.update(prefill, prefill, 0)
cache.update(update3, update3, 0)
self.assertEqual(cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0], "DynamicCache Scenario 1 failed")
self.assertEqual(cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0], "DynamicCache Scenario 1 failed")
cache.update(update4, update4, 0)
self.assertEqual(
cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "DynamicCache Scenario 1 (to 4) failed"
cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "DynamicCache Scenario 1 (to 4) failed"
)
# Scenario 2: prefill and update for two layers independently
@ -1210,8 +1216,10 @@ class SyntheticCacheTest(unittest.TestCase):
cache.update(update4, update4, 0)
cache.update(update4_1, update4_1, 1)
self.assertEqual(
cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "DynamicCache Scenario 2 layer 0 failed"
cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "DynamicCache Scenario 2 layer 0 failed"
)
self.assertEqual(
cache.key_cache[1][0, 0, :, 0].tolist(), [10.0, 20.0, 30.0, 40.0], "DynamicCache Scenario 2 layer 1 failed"
cache.layers[1].keys[0, 0, :, 0].tolist(),
[10.0, 20.0, 30.0, 40.0],
"DynamicCache Scenario 2 layer 1 failed",
)

View File

@ -956,8 +956,9 @@ def fix_docstring(obj: Any, old_doc_args: str, new_doc_args: str):
idx += 1
if "".join(source[start_idx:idx])[:-1] != old_doc_args:
# Args are not fully defined in the docstring of this object
return
raise ValueError(
f"Expected\n{old_doc_args}\nbut got\n{''.join(source[start_idx:idx])[:-1]}\n in {find_source_file(obj)}: {obj.__name__}"
)
obj_file = find_source_file(obj)
with open(obj_file, "r", encoding="utf-8") as f: