Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-03 03:31:05 +06:00)

Commit: 16a6624087
Parent: 26c28af616
Commit message: raushan review, arthur review
@ -82,24 +82,18 @@ When you use Transformers' [`Cache`] class, the self-attention module performs s
## Cache storage implementation

The actual storage of key-value pairs varies between cache implementations. As an example, consider the [`DynamicCache`].

Caches are structured as a list of layers, where each layer contains a key and value cache. The key and value caches are tensors with the shape `[batch_size, num_heads, seq_len, head_dim]`.

Layers can be of different types (e.g. `DynamicLayer`, `StaticLayer`, `SlidingWindowLayer`), which mostly changes how the sequence length is handled and how the cache is updated.

In [`DynamicCache`], the key-value pairs are stored as two lists of tensors. Each tensor in the lists has the shape `[batch_size, num_heads, seq_len, head_dim]`.

- `key_cache`: A list of tensors, one for each layer.
- `value_cache`: A list of tensors, one for each layer.

The simplest layer type is `DynamicLayer`, which grows as more tokens are processed. The sequence length dimension (`seq_len`) increases with each new token:
When new tokens are processed:
1. For each layer, the new key and value states are concatenated with the existing cache.

```py
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
cache.layers[idx].keys = torch.cat([cache.layers[idx].keys, key_states], dim=-2)
cache.layers[idx].values = torch.cat([cache.layers[idx].values, value_states], dim=-2)
```

2. The cache grows dynamically as more tokens are processed. The sequence length dimension (`seq_len`) increases with each new token.

3. The cache maintains a count of seen tokens through `self._seen_tokens`. This is updated when the first layer processes a new token.
Other layers like `StaticLayer` and `SlidingWindowLayer` have a fixed sequence length that is set when the cache is created. This makes them compatible with `torch.compile`. In the case of `SlidingWindowLayer`, existing tokens are shifted out of the cache when a new token is added.
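
To make the layer structure above concrete, here is a minimal, illustrative sketch (toy tensor sizes and a placeholder `gpt2` config, not taken from the documentation) that fills a [`DynamicCache`] by hand and creates a fixed-size [`StaticCache`] using the `model_config` argument introduced in this diff:

```py
import torch
from transformers import AutoConfig, DynamicCache, StaticCache

# Toy key/value states with shape [batch_size, num_heads, seq_len, head_dim]
key_states = torch.randn(1, 8, 4, 64)
value_states = torch.randn(1, 8, 4, 64)

# DynamicLayer-backed cache: storage grows along seq_len with each update
dynamic_cache = DynamicCache()
dynamic_cache.update(key_states, value_states, layer_idx=0)
print(dynamic_cache.layers[0].keys.shape)  # expected: torch.Size([1, 8, 4, 64])

# StaticLayer-backed cache: the sequence length is fixed at max_cache_len,
# which is what makes it compatible with torch.compile
config = AutoConfig.from_pretrained("gpt2")  # placeholder model config
static_cache = StaticCache(model_config=config, max_batch_size=1, max_cache_len=128)
```
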
The example below demonstrates how to create a generation loop with [`DynamicCache`]. As discussed, the attention mask is a concatenation of past and current token values and `1` is added to the cache position for the next token.
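
As a rough sketch, such a loop might look like the following (the checkpoint name and the greedy decoding choice are placeholders, not part of the original example):

```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

model_id = "gpt2"  # placeholder decoder-only checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("The key-value cache", return_tensors="pt")
input_ids, attention_mask = inputs.input_ids, inputs.attention_mask
past_key_values = DynamicCache()
cache_position = torch.arange(input_ids.shape[1])

for _ in range(10):
    outputs = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        cache_position=cache_position,
        use_cache=True,
    )
    next_token = outputs.logits[:, -1].argmax(dim=-1, keepdim=True)
    # Extend the mask with a 1 for the new token and advance the cache position by 1
    attention_mask = torch.cat([attention_mask, torch.ones_like(next_token)], dim=-1)
    cache_position = cache_position[-1:] + 1
    input_ids = next_token
```
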
@ -143,7 +137,7 @@ The legacy format is essentially the same data structure but organized different
- The tensors have the same shape `[batch_size, num_heads, seq_len, head_dim]`.
- The format is less flexible and doesn't support features like quantization or offloading.

If your project depends on this legacy format, you can convert between [`DynamicCache`] and a tuple of tuples as shown below with the [`~DynamicCache.from_legacy_cache`] and [`DynamicCache.to_legacy_cache`] functions. This is helpful if you have custom logic for manipulating a cache in a specific format.
If your project depends on this legacy format, you can convert between [`DynamicCache`] and a tuple of tuples as shown below with the [`~Cache.from_legacy_cache`] and [`Cache.to_legacy_cache`] functions. This is helpful if you have custom logic for manipulating a cache in a specific format.
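
The round trip itself is a one-liner in each direction. Here is a minimal sketch with toy tensors (shapes and sizes are illustrative only):

```py
import torch
from transformers import DynamicCache

# Toy legacy cache: one layer, a (key, value) tuple of [batch_size, num_heads, seq_len, head_dim] tensors
legacy_cache = ((torch.randn(1, 8, 4, 64), torch.randn(1, 8, 4, 64)),)

cache = DynamicCache.from_legacy_cache(legacy_cache)
print(cache.layers[0].keys.shape)  # the same tensors, now stored per layer

round_trip = cache.to_legacy_cache()
print(torch.equal(round_trip[0][0], legacy_cache[0][0]))  # expected: True
```
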
```py
import torch
```
File diff suppressed because it is too large
@ -1951,7 +1951,7 @@ class GenerationMixin(ContinuousMixin):
layer_device_map = self._get_layer_device_map_for_cache_init()
cache_kwargs = {
"config": self.config.get_text_config(),
"model_config": self.config.get_text_config(),
"max_batch_size": batch_size,
"max_cache_len": max_cache_len,
"dtype": cache_dtype,
@ -275,15 +275,15 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
self.model = model
self.static_cache = StaticCache(
config=self.model.config,
model_config=self.model.config,
max_batch_size=self.model.generation_config.cache_config.batch_size,
max_cache_len=self.model.generation_config.cache_config.max_cache_len,
device=self.model.generation_config.cache_config.device,
dtype=self.model.dtype,
)
for i in range(len(self.static_cache.key_cache)):
self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i], persistent=False)
self.register_buffer(f"value_cache_{i}", self.static_cache.value_cache[i], persistent=False)
for i in range(len(self.static_cache)):
self.register_buffer(f"key_cache_{i}", self.static_cache.layers[i].keys, persistent=False)
self.register_buffer(f"value_cache_{i}", self.static_cache.layers[i].values, persistent=False)

def forward(self, input_ids: torch.Tensor, cache_position: torch.Tensor):
"""
@ -404,7 +404,7 @@ class TorchExportableModuleWithHybridCache(torch.nn.Module):
# Initialize the HybridCache
self.cache = HybridCache(
config=self.model.config,
model_config=self.model.config,
max_batch_size=max_batch_size,
max_cache_len=max_cache_len,
device=self.model.device,
@ -412,9 +412,9 @@ class TorchExportableModuleWithHybridCache(torch.nn.Module):
)

# Register all key and value cache tensors as buffers
for i in range(len(self.cache.key_cache)):
self.register_buffer(f"key_cache_{i}", self.cache.key_cache[i], persistent=False)
self.register_buffer(f"value_cache_{i}", self.cache.value_cache[i], persistent=False)
for i in range(len(self.cache)):
self.register_buffer(f"key_cache_{i}", self.cache.layers[i].keys, persistent=False)
self.register_buffer(f"value_cache_{i}", self.cache.layers[i].values, persistent=False)

def forward(
self,
@ -550,7 +550,7 @@ class Seq2SeqLMDecoderExportableModuleWithStaticCache(torch.nn.Module):
# Initialize static cache
self.static_cache = StaticCache(
config=self.config,
model_config=self.config,
max_batch_size=batch_size,
max_cache_len=max_static_cache_length,
device="cpu",
@ -558,9 +558,9 @@ class Seq2SeqLMDecoderExportableModuleWithStaticCache(torch.nn.Module):
)

# Register cache buffers to make them exportable
for i in range(len(self.static_cache.key_cache)):
self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i], persistent=False)
self.register_buffer(f"value_cache_{i}", self.static_cache.value_cache[i], persistent=False)
for i in range(len(self.static_cache)):
self.register_buffer(f"key_cache_{i}", self.static_cache.layers[i].keys, persistent=False)
self.register_buffer(f"value_cache_{i}", self.static_cache.layers[i].values, persistent=False)

def forward(self, decoder_input_ids, encoder_hidden_states, cache_position):
# Get outputs from decoder
@ -692,10 +692,10 @@ def create_causal_mask(
useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
"""
# If we have an HybridCache structure, here we want to create the mask for the full layers
if hasattr(past_key_values, "is_sliding") and False in past_key_values.is_sliding:
layer_idx = past_key_values.is_sliding.index(False)
else:
layer_idx = 0
is_sliding = []
if past_key_values is not None:
is_sliding = [getattr(layer, "is_sliding", False) for layer in past_key_values.layers]
layer_idx = is_sliding.index(True) if True in is_sliding else 0

early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments(
config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx
@ -774,10 +774,10 @@ def create_sliding_window_causal_mask(
useful to easily overlay another mask on top of the sliding causal one, for example for image tokens handling.
"""
# If we have an HybridCache structure, here we want to create the mask for the sliding layers
if hasattr(past_key_values, "is_sliding") and True in past_key_values.is_sliding:
layer_idx = past_key_values.is_sliding.index(True)
else:
layer_idx = 0
is_sliding = []
if past_key_values is not None:
is_sliding = [getattr(layer, "is_sliding", False) for layer in past_key_values.layers]
layer_idx = is_sliding.index(True) if True in is_sliding else 0

early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments(
config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx
@ -861,10 +861,10 @@ def create_chunked_causal_mask(
useful to easily overlay another mask on top of the chunked causal one, for example for image tokens handling.
"""
# If we have an HybridCache structure, here we want to create the mask for the sliding layers
if hasattr(past_key_values, "is_sliding") and True in past_key_values.is_sliding:
layer_idx = past_key_values.is_sliding.index(True)
else:
layer_idx = 0
is_sliding = []
if past_key_values is not None:
is_sliding = [getattr(layer, "is_sliding", False) for layer in past_key_values.layers]
layer_idx = is_sliding.index(True) if True in is_sliding else 0

early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments(
config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx
@ -230,8 +230,8 @@ class BartAttention(nn.Module):
|
||||
current_states = key_value_states if is_cross_attention else hidden_states
|
||||
if is_cross_attention and past_key_value is not None and is_updated:
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = curr_past_key_value.key_cache[self.layer_idx]
|
||||
value_states = curr_past_key_value.value_cache[self.layer_idx]
|
||||
key_states = curr_past_key_value.layers[self.layer_idx].keys
|
||||
value_states = curr_past_key_value.layers[self.layer_idx].values
|
||||
else:
|
||||
key_states = self.k_proj(current_states)
|
||||
value_states = self.v_proj(current_states)
|
||||
|
@ -1293,8 +1293,8 @@ class BigBirdPegasusDecoderAttention(nn.Module):
|
||||
current_states = key_value_states if is_cross_attention else hidden_states
|
||||
if is_cross_attention and past_key_value is not None and is_updated:
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = curr_past_key_value.key_cache[self.layer_idx]
|
||||
value_states = curr_past_key_value.value_cache[self.layer_idx]
|
||||
key_states = curr_past_key_value.layers[self.layer_idx].keys
|
||||
value_states = curr_past_key_value.layers[self.layer_idx].values
|
||||
else:
|
||||
key_states = self.k_proj(current_states)
|
||||
value_states = self.v_proj(current_states)
|
||||
|
@ -207,8 +207,8 @@ class BioGptAttention(nn.Module):
|
||||
current_states = key_value_states if is_cross_attention else hidden_states
|
||||
if is_cross_attention and past_key_value is not None and is_updated:
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = curr_past_key_value.key_cache[self.layer_idx]
|
||||
value_states = curr_past_key_value.value_cache[self.layer_idx]
|
||||
key_states = curr_past_key_value.layers[self.layer_idx].keys
|
||||
value_states = curr_past_key_value.layers[self.layer_idx].values
|
||||
else:
|
||||
key_states = self.k_proj(current_states)
|
||||
value_states = self.v_proj(current_states)
|
||||
|
@ -229,8 +229,8 @@ class BlenderbotAttention(nn.Module):
|
||||
current_states = key_value_states if is_cross_attention else hidden_states
|
||||
if is_cross_attention and past_key_value is not None and is_updated:
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = curr_past_key_value.key_cache[self.layer_idx]
|
||||
value_states = curr_past_key_value.value_cache[self.layer_idx]
|
||||
key_states = curr_past_key_value.layers[self.layer_idx].keys
|
||||
value_states = curr_past_key_value.layers[self.layer_idx].values
|
||||
else:
|
||||
key_states = self.k_proj(current_states)
|
||||
value_states = self.v_proj(current_states)
|
||||
|
@ -213,8 +213,8 @@ class BlenderbotSmallAttention(nn.Module):
|
||||
current_states = key_value_states if is_cross_attention else hidden_states
|
||||
if is_cross_attention and past_key_value is not None and is_updated:
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = curr_past_key_value.key_cache[self.layer_idx]
|
||||
value_states = curr_past_key_value.value_cache[self.layer_idx]
|
||||
key_states = curr_past_key_value.layers[self.layer_idx].keys
|
||||
value_states = curr_past_key_value.layers[self.layer_idx].values
|
||||
else:
|
||||
key_states = self.k_proj(current_states)
|
||||
value_states = self.v_proj(current_states)
|
||||
|
@ -356,8 +356,8 @@ class DiaCrossAttention(nn.Module):
|
||||
is_updated = past_key_values.is_updated.get(self.layer_idx) if past_key_values is not None else False
|
||||
if past_key_values is not None and is_updated:
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = past_key_values.cross_attention_cache.key_cache[self.layer_idx]
|
||||
value_states = past_key_values.cross_attention_cache.value_cache[self.layer_idx]
|
||||
key_states = past_key_values.cross_attention_cache.layers[self.layer_idx].keys
|
||||
value_states = past_key_values.cross_attention_cache.layers[self.layer_idx].values
|
||||
else:
|
||||
key_states = self.k_proj(cross_attention_states).view(cross_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(cross_attention_states).view(cross_shape).transpose(1, 2)
|
||||
|
@ -182,8 +182,8 @@ class DiaCrossAttention(nn.Module):
|
||||
is_updated = past_key_values.is_updated.get(self.layer_idx) if past_key_values is not None else False
|
||||
if past_key_values is not None and is_updated:
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = past_key_values.cross_attention_cache.key_cache[self.layer_idx]
|
||||
value_states = past_key_values.cross_attention_cache.value_cache[self.layer_idx]
|
||||
key_states = past_key_values.cross_attention_cache.layers[self.layer_idx].keys
|
||||
value_states = past_key_values.cross_attention_cache.layers[self.layer_idx].values
|
||||
else:
|
||||
key_states = self.k_proj(cross_attention_states).view(cross_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(cross_attention_states).view(cross_shape).transpose(1, 2)
|
||||
|
@ -1332,8 +1332,8 @@ class Gemma3nTextAttention(nn.Module):
|
||||
else:
|
||||
indices = cache_position
|
||||
|
||||
key_states = past_key_value.key_cache[self.kv_shared_layer_index][:, :, indices]
|
||||
value_states = past_key_value.value_cache[self.kv_shared_layer_index][:, :, indices]
|
||||
key_states = past_key_value.layers[self.kv_shared_layer_index].keys[:, :, indices]
|
||||
value_states = past_key_value.layers[self.kv_shared_layer_index].values[:, :, indices]
|
||||
else:
|
||||
key_states = self.k_proj(hidden_states).view(hidden_shape)
|
||||
key_states = self.k_norm(key_states)
|
||||
|
@ -1774,8 +1774,8 @@ class Gemma3nTextAttention(Gemma3Attention):
|
||||
else:
|
||||
indices = cache_position
|
||||
|
||||
key_states = past_key_value.key_cache[self.kv_shared_layer_index][:, :, indices]
|
||||
value_states = past_key_value.value_cache[self.kv_shared_layer_index][:, :, indices]
|
||||
key_states = past_key_value.layers[self.kv_shared_layer_index].keys[:, :, indices]
|
||||
value_states = past_key_value.layers[self.kv_shared_layer_index].values[:, :, indices]
|
||||
else:
|
||||
key_states = self.k_proj(hidden_states).view(hidden_shape)
|
||||
key_states = self.k_norm(key_states)
|
||||
|
@ -728,8 +728,9 @@ class GPTJModel(GPTJPreTrainedModel):
|
||||
|
||||
# Ensure layer_past is on same device as hidden_states (might not be correct)
|
||||
if past_key_values is not None:
|
||||
past_key_values.key_cache = past_key_values.key_cache.to(hidden_states.device)
|
||||
past_key_values.value_cache = past_key_values.value_cache.to(hidden_states.device)
|
||||
for layer in past_key_values.layers:
|
||||
layer.keys = layer.keys.to(hidden_states.device)
|
||||
layer.values = layer.values.to(hidden_states.device)
|
||||
|
||||
# Ensure that attention_mask is always on the same device as hidden_states
|
||||
if causal_mask is not None:
|
||||
|
@ -484,8 +484,8 @@ class InformerAttention(nn.Module):
|
||||
current_states = key_value_states if is_cross_attention else hidden_states
|
||||
if is_cross_attention and past_key_value is not None and is_updated:
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = curr_past_key_value.key_cache[self.layer_idx]
|
||||
value_states = curr_past_key_value.value_cache[self.layer_idx]
|
||||
key_states = curr_past_key_value.layers[self.layer_idx].keys
|
||||
value_states = curr_past_key_value.layers[self.layer_idx].values
|
||||
else:
|
||||
key_states = self.k_proj(current_states)
|
||||
value_states = self.v_proj(current_states)
|
||||
@ -601,8 +601,8 @@ class InformerProbSparseAttention(nn.Module):
|
||||
current_states = key_value_states if is_cross_attention else hidden_states
|
||||
if is_cross_attention and past_key_value is not None and is_updated:
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = curr_past_key_value.key_cache[self.layer_idx]
|
||||
value_states = curr_past_key_value.value_cache[self.layer_idx]
|
||||
key_states = curr_past_key_value.layers[self.layer_idx].keys
|
||||
value_states = curr_past_key_value.layers[self.layer_idx].values
|
||||
else:
|
||||
key_states = self.k_proj(current_states)
|
||||
value_states = self.v_proj(current_states)
|
||||
|
@ -290,8 +290,8 @@ class InformerProbSparseAttention(nn.Module):
|
||||
current_states = key_value_states if is_cross_attention else hidden_states
|
||||
if is_cross_attention and past_key_value is not None and is_updated:
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = curr_past_key_value.key_cache[self.layer_idx]
|
||||
value_states = curr_past_key_value.value_cache[self.layer_idx]
|
||||
key_states = curr_past_key_value.layers[self.layer_idx].keys
|
||||
value_states = curr_past_key_value.layers[self.layer_idx].values
|
||||
else:
|
||||
key_states = self.k_proj(current_states)
|
||||
value_states = self.v_proj(current_states)
|
||||
|
@ -478,8 +478,8 @@ class LongT5Attention(nn.Module):
|
||||
current_states = key_value_states if is_cross_attention else hidden_states
|
||||
if is_cross_attention and past_key_value is not None and is_updated:
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = curr_past_key_value.key_cache[self.layer_idx]
|
||||
value_states = curr_past_key_value.value_cache[self.layer_idx]
|
||||
key_states = curr_past_key_value.layers[self.layer_idx].keys
|
||||
value_states = curr_past_key_value.layers[self.layer_idx].values
|
||||
else:
|
||||
key_states = self.k(current_states)
|
||||
value_states = self.v(current_states)
|
||||
|
@ -294,8 +294,8 @@ class M2M100Attention(nn.Module):
|
||||
current_states = key_value_states if is_cross_attention else hidden_states
|
||||
if is_cross_attention and past_key_value is not None and is_updated:
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = curr_past_key_value.key_cache[self.layer_idx]
|
||||
value_states = curr_past_key_value.value_cache[self.layer_idx]
|
||||
key_states = curr_past_key_value.layers[self.layer_idx].keys
|
||||
value_states = curr_past_key_value.layers[self.layer_idx].values
|
||||
else:
|
||||
key_states = self.k_proj(current_states)
|
||||
value_states = self.v_proj(current_states)
|
||||
|
@ -229,8 +229,8 @@ class MarianAttention(nn.Module):
|
||||
current_states = key_value_states if is_cross_attention else hidden_states
|
||||
if is_cross_attention and past_key_value is not None and is_updated:
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = curr_past_key_value.key_cache[self.layer_idx]
|
||||
value_states = curr_past_key_value.value_cache[self.layer_idx]
|
||||
key_states = curr_past_key_value.layers[self.layer_idx].keys
|
||||
value_states = curr_past_key_value.layers[self.layer_idx].values
|
||||
else:
|
||||
key_states = self.k_proj(current_states)
|
||||
value_states = self.v_proj(current_states)
|
||||
|
@ -239,8 +239,8 @@ class MBartAttention(nn.Module):
|
||||
current_states = key_value_states if is_cross_attention else hidden_states
|
||||
if is_cross_attention and past_key_value is not None and is_updated:
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = curr_past_key_value.key_cache[self.layer_idx]
|
||||
value_states = curr_past_key_value.value_cache[self.layer_idx]
|
||||
key_states = curr_past_key_value.layers[self.layer_idx].keys
|
||||
value_states = curr_past_key_value.layers[self.layer_idx].values
|
||||
else:
|
||||
key_states = self.k_proj(current_states)
|
||||
value_states = self.v_proj(current_states)
|
||||
|
@ -105,16 +105,14 @@ class MiniMaxCache(DynamicCache):
|
||||
if self.linear_cache[layer_idx] != []:
|
||||
self.linear_cache[layer_idx] = self.linear_cache[layer_idx].repeat_interleave(repeats, dim=0)
|
||||
else:
|
||||
self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0)
|
||||
self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0)
|
||||
self.layers[layer_idx].batch_repeat_interleave(repeats)
|
||||
|
||||
def batch_select_indices(self, indices: torch.Tensor):
|
||||
for layer_idx in range(len(self)):
|
||||
if self.linear_cache[layer_idx] != []:
|
||||
self.linear_cache[layer_idx] = self.linear_cache[layer_idx][indices, ...]
|
||||
else:
|
||||
self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...]
|
||||
self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]
|
||||
self.layers[layer_idx].batch_select_indices(indices)
|
||||
|
||||
def crop(self, max_length: int):
|
||||
raise RuntimeError("MiniMaxCache doesnot support `crop` method")
|
||||
|
@ -215,16 +215,14 @@ class MiniMaxCache(DynamicCache):
|
||||
if self.linear_cache[layer_idx] != []:
|
||||
self.linear_cache[layer_idx] = self.linear_cache[layer_idx].repeat_interleave(repeats, dim=0)
|
||||
else:
|
||||
self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0)
|
||||
self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0)
|
||||
self.layers[layer_idx].batch_repeat_interleave(repeats)
|
||||
|
||||
def batch_select_indices(self, indices: torch.Tensor):
|
||||
for layer_idx in range(len(self)):
|
||||
if self.linear_cache[layer_idx] != []:
|
||||
self.linear_cache[layer_idx] = self.linear_cache[layer_idx][indices, ...]
|
||||
else:
|
||||
self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...]
|
||||
self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]
|
||||
self.layers[layer_idx].batch_select_indices(indices)
|
||||
|
||||
def crop(self, max_length: int):
|
||||
raise RuntimeError("MiniMaxCache doesnot support `crop` method")
|
||||
|
@ -496,8 +496,8 @@ class MllamaTextCrossAttention(nn.Module):
|
||||
)
|
||||
elif cache_position[0] != 0:
|
||||
key_states, value_states = (
|
||||
past_key_value.key_cache[self.layer_idx],
|
||||
past_key_value.value_cache[self.layer_idx],
|
||||
past_key_value.layers[self.layer_idx].keys,
|
||||
past_key_value.layers[self.layer_idx].values,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
|
@ -235,8 +235,8 @@ class MoonshineAttention(nn.Module):
|
||||
# use key_value_states if cross attention
|
||||
current_states = key_value_states if key_value_states is not None else hidden_states
|
||||
if is_cross_attention and past_key_value and is_updated:
|
||||
key_states = past_key_value.key_cache[self.layer_idx]
|
||||
value_states = past_key_value.value_cache[self.layer_idx]
|
||||
key_states = past_key_value.layers[self.layer_idx].keys
|
||||
value_states = past_key_value.layers[self.layer_idx].values
|
||||
else:
|
||||
key_states = (
|
||||
self.k_proj(current_states)
|
||||
|
@ -331,8 +331,8 @@ class MoonshineAttention(GlmAttention):
|
||||
# use key_value_states if cross attention
|
||||
current_states = key_value_states if key_value_states is not None else hidden_states
|
||||
if is_cross_attention and past_key_value and is_updated:
|
||||
key_states = past_key_value.key_cache[self.layer_idx]
|
||||
value_states = past_key_value.value_cache[self.layer_idx]
|
||||
key_states = past_key_value.layers[self.layer_idx].keys
|
||||
value_states = past_key_value.layers[self.layer_idx].values
|
||||
else:
|
||||
key_states = (
|
||||
self.k_proj(current_states)
|
||||
|
@ -376,8 +376,8 @@ class MT5Attention(nn.Module):
|
||||
current_states = key_value_states if is_cross_attention else hidden_states
|
||||
if is_cross_attention and past_key_value is not None and is_updated:
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = curr_past_key_value.key_cache[self.layer_idx]
|
||||
value_states = curr_past_key_value.value_cache[self.layer_idx]
|
||||
key_states = curr_past_key_value.layers[self.layer_idx].keys
|
||||
value_states = curr_past_key_value.layers[self.layer_idx].values
|
||||
else:
|
||||
key_states = self.k(current_states)
|
||||
value_states = self.v(current_states)
|
||||
|
@ -228,8 +228,8 @@ class PegasusAttention(nn.Module):
|
||||
current_states = key_value_states if is_cross_attention else hidden_states
|
||||
if is_cross_attention and past_key_value is not None and is_updated:
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = curr_past_key_value.key_cache[self.layer_idx]
|
||||
value_states = curr_past_key_value.value_cache[self.layer_idx]
|
||||
key_states = curr_past_key_value.layers[self.layer_idx].keys
|
||||
value_states = curr_past_key_value.layers[self.layer_idx].values
|
||||
else:
|
||||
key_states = self.k_proj(current_states)
|
||||
value_states = self.v_proj(current_states)
|
||||
|
@ -249,8 +249,8 @@ class PegasusXAttention(nn.Module):
|
||||
current_states = key_value_states if is_cross_attention else hidden_states
|
||||
if is_cross_attention and past_key_value is not None and is_updated:
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = curr_past_key_value.key_cache[self.layer_idx]
|
||||
value_states = curr_past_key_value.value_cache[self.layer_idx]
|
||||
key_states = curr_past_key_value.layers[self.layer_idx].keys
|
||||
value_states = curr_past_key_value.layers[self.layer_idx].values
|
||||
else:
|
||||
key_states = self.k_proj(current_states)
|
||||
value_states = self.v_proj(current_states)
|
||||
|
@ -770,8 +770,8 @@ class Pix2StructTextAttention(nn.Module):
|
||||
current_states = key_value_states if is_cross_attention else hidden_states
|
||||
if is_cross_attention and past_key_value and is_updated:
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = curr_past_key_value.key_cache[self.layer_idx]
|
||||
value_states = curr_past_key_value.value_cache[self.layer_idx]
|
||||
key_states = curr_past_key_value.layers[self.layer_idx].keys
|
||||
value_states = curr_past_key_value.layers[self.layer_idx].values
|
||||
else:
|
||||
key_states = self.key(current_states)
|
||||
value_states = self.value(current_states)
|
||||
|
@ -425,8 +425,8 @@ class PLBartAttention(nn.Module):
|
||||
current_states = key_value_states if is_cross_attention else hidden_states
|
||||
if is_cross_attention and past_key_value is not None and is_updated:
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = curr_past_key_value.key_cache[self.layer_idx]
|
||||
value_states = curr_past_key_value.value_cache[self.layer_idx]
|
||||
key_states = curr_past_key_value.layers[self.layer_idx].keys
|
||||
value_states = curr_past_key_value.layers[self.layer_idx].values
|
||||
else:
|
||||
key_states = self.k_proj(current_states)
|
||||
value_states = self.v_proj(current_states)
|
||||
|
@ -320,8 +320,8 @@ class Pop2PianoAttention(nn.Module):
|
||||
current_states = key_value_states if is_cross_attention else hidden_states
|
||||
if is_cross_attention and past_key_value is not None and is_updated:
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = curr_past_key_value.key_cache[self.layer_idx]
|
||||
value_states = curr_past_key_value.value_cache[self.layer_idx]
|
||||
key_states = curr_past_key_value.layers[self.layer_idx].keys
|
||||
value_states = curr_past_key_value.layers[self.layer_idx].values
|
||||
else:
|
||||
key_states = self.k(current_states)
|
||||
value_states = self.v(current_states)
|
||||
|
@ -513,8 +513,8 @@ class SwitchTransformersAttention(nn.Module):
|
||||
current_states = key_value_states if is_cross_attention else hidden_states
|
||||
if is_cross_attention and past_key_value is not None and is_updated:
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = curr_past_key_value.key_cache[self.layer_idx]
|
||||
value_states = curr_past_key_value.value_cache[self.layer_idx]
|
||||
key_states = curr_past_key_value.layers[self.layer_idx].keys
|
||||
value_states = curr_past_key_value.layers[self.layer_idx].values
|
||||
else:
|
||||
key_states = self.k(current_states)
|
||||
value_states = self.v(current_states)
|
||||
|
@ -501,8 +501,8 @@ class T5Attention(nn.Module):
|
||||
current_states = key_value_states if is_cross_attention else hidden_states
|
||||
if is_cross_attention and past_key_value is not None and is_updated:
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = curr_past_key_value.key_cache[self.layer_idx]
|
||||
value_states = curr_past_key_value.value_cache[self.layer_idx]
|
||||
key_states = curr_past_key_value.layers[self.layer_idx].keys
|
||||
value_states = curr_past_key_value.layers[self.layer_idx].values
|
||||
else:
|
||||
key_states = self.k(current_states)
|
||||
value_states = self.v(current_states)
|
||||
|
@ -352,8 +352,8 @@ class T5GemmaCrossAttention(nn.Module):
|
||||
past_key_value.is_updated[self.layer_idx] = True
|
||||
# cross-attention: reuse cached states
|
||||
else:
|
||||
key_states = curr_past_key_value.key_cache[self.layer_idx]
|
||||
value_states = curr_past_key_value.value_cache[self.layer_idx]
|
||||
key_states = curr_past_key_value.layers[self.layer_idx].keys
|
||||
value_states = curr_past_key_value.layers[self.layer_idx].values
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
|
@ -308,8 +308,8 @@ class T5GemmaCrossAttention(Gemma2Attention):
|
||||
past_key_value.is_updated[self.layer_idx] = True
|
||||
# cross-attention: reuse cached states
|
||||
else:
|
||||
key_states = curr_past_key_value.key_cache[self.layer_idx]
|
||||
value_states = curr_past_key_value.value_cache[self.layer_idx]
|
||||
key_states = curr_past_key_value.layers[self.layer_idx].keys
|
||||
value_states = curr_past_key_value.layers[self.layer_idx].values
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
|
@ -394,8 +394,8 @@ class TimeSeriesTransformerAttention(nn.Module):
|
||||
current_states = key_value_states if is_cross_attention else hidden_states
|
||||
if is_cross_attention and past_key_value is not None and is_updated:
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = curr_past_key_value.key_cache[self.layer_idx]
|
||||
value_states = curr_past_key_value.value_cache[self.layer_idx]
|
||||
key_states = curr_past_key_value.layers[self.layer_idx].keys
|
||||
value_states = curr_past_key_value.layers[self.layer_idx].values
|
||||
else:
|
||||
key_states = self.k_proj(current_states)
|
||||
value_states = self.v_proj(current_states)
|
||||
|
@ -599,8 +599,8 @@ class UdopAttention(nn.Module):
|
||||
current_states = key_value_states if is_cross_attention else hidden_states
|
||||
if is_cross_attention and past_key_value is not None and is_updated:
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = curr_past_key_value.key_cache[self.layer_idx]
|
||||
value_states = curr_past_key_value.value_cache[self.layer_idx]
|
||||
key_states = curr_past_key_value.layers[self.layer_idx].keys
|
||||
value_states = curr_past_key_value.layers[self.layer_idx].values
|
||||
else:
|
||||
key_states = self.k(current_states)
|
||||
value_states = self.v(current_states)
|
||||
|
@ -285,8 +285,8 @@ class UMT5Attention(nn.Module):
|
||||
current_states = encoder_hidden_states if is_cross_attention else hidden_states
|
||||
if is_cross_attention and past_key_value is not None and is_updated:
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = curr_past_key_value.key_cache[self.layer_idx]
|
||||
value_states = curr_past_key_value.value_cache[self.layer_idx]
|
||||
key_states = curr_past_key_value.layers[self.layer_idx].keys
|
||||
value_states = curr_past_key_value.layers[self.layer_idx].values
|
||||
else:
|
||||
key_states = self.k(current_states)
|
||||
value_states = self.v(current_states)
|
||||
|
@ -1140,8 +1140,8 @@ class WhisperGenerationMixin(GenerationMixin):
|
||||
for layer_idx in range(self.config.decoder_layers):
|
||||
layer_past_key_values = []
|
||||
for cache_cls in [values.self_attention_cache, values.cross_attention_cache]:
|
||||
for v in [cache_cls.key_cache, cache_cls.value_cache]:
|
||||
layer_past_key_values.append(v[layer_idx][batch_idx][None].cpu())
|
||||
for v in [cache_cls.layers[layer_idx].keys, cache_cls.layers[layer_idx].values]:
|
||||
layer_past_key_values.append(v[batch_idx][None].cpu())
|
||||
all_past_key_values.append(tuple(layer_past_key_values))
|
||||
return tuple(all_past_key_values)
|
||||
else:
|
||||
|
@ -329,8 +329,8 @@ class WhisperAttention(nn.Module):
|
||||
current_states = key_value_states if key_value_states is not None else hidden_states
|
||||
if is_cross_attention and past_key_value and is_updated:
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = past_key_value.key_cache[self.layer_idx]
|
||||
value_states = past_key_value.value_cache[self.layer_idx]
|
||||
key_states = past_key_value.layers[self.layer_idx].keys
|
||||
value_states = past_key_value.layers[self.layer_idx].values
|
||||
else:
|
||||
key_states = self.k_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim)
|
||||
value_states = self.v_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim)
|
||||
|
@ -1595,9 +1595,7 @@ class GenerationTesterMixin:
|
||||
# 3. Check cache shapes
|
||||
# 3.1. Encoder-Decoder checks
|
||||
if config.is_encoder_decoder:
|
||||
num_cache_decoder_layers = (
|
||||
len(past_kv) if is_legacy_cache else len(past_kv.self_attention_cache.key_cache)
|
||||
)
|
||||
num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv.self_attention_cache)
|
||||
self.assertEqual(num_cache_decoder_layers, num_decoder_layers)
|
||||
|
||||
for i in range(num_decoder_layers):
|
||||
@ -1605,30 +1603,30 @@ class GenerationTesterMixin:
|
||||
self.assertEqual(len(past_kv[0]), 4) # legacy check: confirm number of elements in tuple
|
||||
|
||||
# Self attention
|
||||
self_attention_layer_key_cache = (
|
||||
past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.key_cache[i]
|
||||
self_attention_layer_keys = (
|
||||
past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.layers[i].keys
|
||||
)
|
||||
self_attention_layer_value_cache = (
|
||||
past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.value_cache[i]
|
||||
self_attention_layer_values = (
|
||||
past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.layers[i].values
|
||||
)
|
||||
self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0])
|
||||
self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1])
|
||||
self.assertEqual(self_attention_layer_keys.shape, all_cache_shapes[i][0])
|
||||
self.assertEqual(self_attention_layer_values.shape, all_cache_shapes[i][1])
|
||||
|
||||
# Cross attention (ignore 3rd dim, see default shape preparation)
|
||||
cross_attention_layer_key_cache = (
|
||||
past_kv[i][2] if is_legacy_cache else past_kv.cross_attention_cache.key_cache[i]
|
||||
cross_attention_layer_keys = (
|
||||
past_kv[i][2] if is_legacy_cache else past_kv.cross_attention_cache.layers[i].keys
|
||||
)
|
||||
cross_attention_layer_value_cache = (
|
||||
past_kv[i][3] if is_legacy_cache else past_kv.cross_attention_cache.value_cache[i]
|
||||
cross_attention_layer_values = (
|
||||
past_kv[i][3] if is_legacy_cache else past_kv.cross_attention_cache.layers[i].values
|
||||
)
|
||||
cross_attention_layer_key_cache = cross_attention_layer_key_cache[:, :, 0, :]
|
||||
cross_attention_layer_value_cache = cross_attention_layer_value_cache[:, :, 0, :]
|
||||
self.assertEqual(cross_attention_layer_key_cache.shape, all_cache_shapes[i][2])
|
||||
self.assertEqual(cross_attention_layer_value_cache.shape, all_cache_shapes[i][3])
|
||||
cross_attention_layer_keys = cross_attention_layer_keys[:, :, 0, :]
|
||||
cross_attention_layer_values = cross_attention_layer_values[:, :, 0, :]
|
||||
self.assertEqual(cross_attention_layer_keys.shape, all_cache_shapes[i][2])
|
||||
self.assertEqual(cross_attention_layer_values.shape, all_cache_shapes[i][3])
|
||||
|
||||
# 3.2. Decoder-only checks
|
||||
else:
|
||||
num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv.key_cache)
|
||||
num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv)
|
||||
self.assertEqual(num_cache_decoder_layers, num_decoder_layers)
|
||||
|
||||
for i in range(num_decoder_layers):
|
||||
@ -1636,10 +1634,10 @@ class GenerationTesterMixin:
|
||||
self.assertEqual(len(past_kv[0]), 2) # legacy check: confirm number of elements in tuple
|
||||
|
||||
# Self attention
|
||||
self_attention_layer_key_cache = past_kv[i][0] if is_legacy_cache else past_kv.key_cache[i]
|
||||
self_attention_layer_value_cache = past_kv[i][1] if is_legacy_cache else past_kv.value_cache[i]
|
||||
self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0])
|
||||
self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1])
|
||||
self_attention_layer_keys = past_kv[i][0] if is_legacy_cache else past_kv.layers[i].keys
|
||||
self_attention_layer_values = past_kv[i][1] if is_legacy_cache else past_kv.layers[i].values
|
||||
self.assertEqual(self_attention_layer_keys.shape, all_cache_shapes[i][0])
|
||||
self.assertEqual(self_attention_layer_values.shape, all_cache_shapes[i][1])
|
||||
|
||||
@pytest.mark.generate
|
||||
@parameterized.expand([("greedy", 1), ("beam search", 2)])
|
||||
@ -1798,8 +1796,8 @@ class GenerationTesterMixin:
|
||||
max_length = max_new_tokens + inputs_embeds.shape[1] - 1
|
||||
cache_shape = [batch_size, num_key_value_heads, max_length, head_dim]
|
||||
self.assertIsInstance(outputs.past_key_values, StaticCache)
|
||||
self.assertEqual(len(outputs.past_key_values.key_cache), num_hidden_layers)
|
||||
self.assertListEqual(list(outputs.past_key_values.key_cache[0].shape), cache_shape)
|
||||
self.assertEqual(len(outputs.past_key_values), num_hidden_layers)
|
||||
self.assertListEqual(list(outputs.past_key_values.layers[0].keys.shape), cache_shape)
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_generate_continue_from_past_key_values(self):
|
||||
@ -2029,8 +2027,8 @@ class GenerationTesterMixin:
|
||||
num_hidden_layers = text_config.num_hidden_layers
|
||||
cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim)
|
||||
self.assertTrue(isinstance(static_cache_generation.past_key_values, StaticCache))
|
||||
self.assertTrue(len(static_cache_generation.past_key_values.key_cache) == num_hidden_layers)
|
||||
self.assertTrue(static_cache_generation.past_key_values.key_cache[0].shape == cache_shape)
|
||||
self.assertTrue(len(static_cache_generation.past_key_values) == num_hidden_layers)
|
||||
self.assertTrue(static_cache_generation.past_key_values.layers[0].keys.shape == cache_shape)
|
||||
|
||||
# Check 2: The outputs must be similar to the case with dynamic cache
|
||||
dynamic_cache_generation = model.generate(**generation_kwargs, **inputs_dict)
|
||||
@ -2612,12 +2610,12 @@ class GenerationTesterMixin:
|
||||
|
||||
if isinstance(decoder_past_key_values, Cache):
|
||||
self.assertListEqual(
|
||||
[key_tensor.shape for key_tensor in decoder_past_key_values.key_cache],
|
||||
[expected_shape] * len(decoder_past_key_values.key_cache),
|
||||
[layer.keys.shape for layer in decoder_past_key_values.layers],
|
||||
[expected_shape] * len(decoder_past_key_values.layers),
|
||||
)
|
||||
self.assertListEqual(
|
||||
[value_tensor.shape for value_tensor in decoder_past_key_values.value_cache],
|
||||
[expected_shape] * len(decoder_past_key_values.value_cache),
|
||||
[layer.values.shape for layer in decoder_past_key_values.layers],
|
||||
[expected_shape] * len(decoder_past_key_values.layers),
|
||||
)
|
||||
|
||||
# Legacy cache format checks. This branch should be removed when all models use `Cache` by default
|
||||
@ -3976,13 +3974,13 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
self.assertTrue(isinstance(results.past_key_values, StaticCache))
|
||||
|
||||
# check device of each layer
|
||||
key_cache_0 = results.past_key_values.key_cache[0]
|
||||
value_cache_0 = results.past_key_values.value_cache[0]
|
||||
self.assertTrue(key_cache_0.device == value_cache_0.device == torch.device(0))
|
||||
keys_0 = results.past_key_values.layers[0].keys
|
||||
values_0 = results.past_key_values.layers[0].values
|
||||
self.assertTrue(keys_0.device == values_0.device == torch.device(0))
|
||||
|
||||
key_cache_1 = results.past_key_values.key_cache[1]
|
||||
value_cache_1 = results.past_key_values.value_cache[1]
|
||||
self.assertTrue(key_cache_1.device == value_cache_1.device == torch.device(1))
|
||||
keys_1 = results.past_key_values.layers[1].keys
|
||||
values_1 = results.past_key_values.layers[1].values
|
||||
self.assertTrue(keys_1.device == values_1.device == torch.device(1))
|
||||
|
||||
@pytest.mark.generate
|
||||
@require_torch_multi_accelerator
|
||||
@ -4054,13 +4052,13 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
results = model.generate(input_ids, past_key_values=past_key_values, **generation_kwargs)
|
||||
|
||||
# check device of each layer
|
||||
key_cache_0 = results.past_key_values.key_cache[0]
|
||||
value_cache_0 = results.past_key_values.value_cache[0]
|
||||
self.assertTrue(key_cache_0.device == value_cache_0.device == torch.device(0))
|
||||
keys_0 = results.past_key_values.layers[0].keys
|
||||
values_0 = results.past_key_values.layers[0].values
|
||||
self.assertTrue(keys_0.device == values_0.device == torch.device(0))
|
||||
|
||||
key_cache_1 = results.past_key_values.key_cache[1]
|
||||
value_cache_1 = results.past_key_values.value_cache[1]
|
||||
self.assertTrue(key_cache_1.device == value_cache_1.device == torch.device(1))
|
||||
keys_1 = results.past_key_values.layers[1].keys
|
||||
values_1 = results.past_key_values.layers[1].values
|
||||
self.assertTrue(keys_1.device == values_1.device == torch.device(1))
|
||||
|
||||
@slow
|
||||
def test_padding_input_contrastive_search_gpt2(self):
|
||||
|
@ -440,13 +440,11 @@ class DeepseekV3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
||||
# difference: last dim
|
||||
k_embed_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
|
||||
v_embed_dim = config.v_head_dim
|
||||
self_attention_key_cache_shape = (batch_size, config.num_key_value_heads, seq_length, k_embed_dim)
|
||||
self_attention_value_cache_shape = (batch_size, config.num_key_value_heads, seq_length, v_embed_dim)
|
||||
self_attention_keys_shape = (batch_size, config.num_key_value_heads, seq_length, k_embed_dim)
|
||||
self_attention_values_shape = (batch_size, config.num_key_value_heads, seq_length, v_embed_dim)
|
||||
# build the full cache shapes
|
||||
num_hidden_layers = config.num_hidden_layers
|
||||
all_cache_shapes = [
|
||||
[self_attention_key_cache_shape, self_attention_value_cache_shape] for _ in range(num_hidden_layers)
|
||||
]
|
||||
all_cache_shapes = [[self_attention_keys_shape, self_attention_values_shape] for _ in range(num_hidden_layers)]
|
||||
super().test_past_key_values_format(custom_all_cache_shapes=all_cache_shapes)
|
||||
|
||||
@require_torch_large_accelerator
|
||||
|
@ -399,12 +399,12 @@ class DiaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
|
||||
if isinstance(decoder_past_key_values, Cache):
|
||||
self.assertListEqual(
|
||||
[key_tensor.shape for key_tensor in decoder_past_key_values.key_cache],
|
||||
[expected_shape] * len(decoder_past_key_values.key_cache),
|
||||
[layer.keys.shape for layer in decoder_past_key_values.layers],
|
||||
[expected_shape] * len(decoder_past_key_values.layers),
|
||||
)
|
||||
self.assertListEqual(
|
||||
[value_tensor.shape for value_tensor in decoder_past_key_values.value_cache],
|
||||
[expected_shape] * len(decoder_past_key_values.value_cache),
|
||||
[layer.values.shape for layer in decoder_past_key_values.layers],
|
||||
[expected_shape] * len(decoder_past_key_values.layers),
|
||||
)
|
||||
|
||||
def _check_scores(self, batch_size, scores, generated_length, config):
|
||||
|
@ -235,8 +235,8 @@ class GPTNeoXModelTester:
|
||||
"""Deep copy a DynamicCache to reuse the same one multiple times."""
|
||||
new_cache = cache
|
||||
for i in range(len(cache)):
|
||||
new_cache.key_cache[i] = cache.key_cache[i].clone()
|
||||
new_cache.value_cache[i] = cache.value_cache[i].clone()
|
||||
new_cache.layers[i].keys = cache.layers[i].keys.clone()
|
||||
new_cache.layers[i].values = cache.layers[i].values.clone()
|
||||
|
||||
# Cached forward once with the attention mask provided and the other time without it (which should assume full attention)
|
||||
# We need to run both on a copy of the cache, otherwise it is modified in-place
|
||||
|
@ -272,7 +272,7 @@ class T5GemmaModelTester:
|
||||
self.parent.assertEqual(decoder_output.size(), (self.batch_size, self.seq_length, self.hidden_size))
|
||||
self.parent.assertIsNotNone(decoder_past)
|
||||
self.parent.assertEqual(len(decoder_past.self_attention_cache), config.decoder.num_hidden_layers)
|
||||
self.parent.assertEqual(len(decoder_past.cross_attention_cache.key_cache), config.decoder.num_hidden_layers)
|
||||
self.parent.assertEqual(len(decoder_past.cross_attention_cache), config.decoder.num_hidden_layers)
|
||||
|
||||
def check_prepare_lm_labels_via_shift_left(
|
||||
self,
|
||||
@ -1069,9 +1069,7 @@ class T5GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
# 3. Check cache shapes
|
||||
# 3.1. Encoder-Decoder checks
|
||||
if config.is_encoder_decoder:
|
||||
num_cache_decoder_layers = (
|
||||
len(past_kv) if is_legacy_cache else len(past_kv.self_attention_cache.key_cache)
|
||||
)
|
||||
num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv.self_attention_cache)
|
||||
self.assertEqual(num_cache_decoder_layers, num_decoder_layers)
|
||||
|
||||
for i in range(num_decoder_layers):
|
||||
@ -1079,30 +1077,30 @@ class T5GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
self.assertEqual(len(past_kv[0]), 4) # legacy check: confirm number of elements in tuple
|
||||
|
||||
# Self attention
|
||||
self_attention_layer_key_cache = (
|
||||
past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.key_cache[i]
|
||||
self_attention_layer_keys = (
|
||||
past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.layers[i].keys
|
||||
)
|
||||
self_attention_layer_value_cache = (
|
||||
past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.value_cache[i]
|
||||
self_attention_layer_values = (
|
||||
past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.layers[i].values
|
||||
)
|
||||
self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0])
|
||||
self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1])
|
||||
self.assertEqual(self_attention_layer_keys.shape, all_cache_shapes[i][0])
|
||||
self.assertEqual(self_attention_layer_values.shape, all_cache_shapes[i][1])
|
||||
|
||||
# Cross attention (ignore 3rd dim, see default shape preparation)
|
||||
cross_attention_layer_key_cache = (
|
||||
past_kv[i][2] if is_legacy_cache else past_kv.cross_attention_cache.key_cache[i]
|
||||
cross_attention_layer_keys = (
|
||||
past_kv[i][2] if is_legacy_cache else past_kv.cross_attention_cache.layers[i].keys
|
||||
)
|
||||
cross_attention_layer_value_cache = (
|
||||
past_kv[i][3] if is_legacy_cache else past_kv.cross_attention_cache.value_cache[i]
|
||||
cross_attention_layer_values = (
|
||||
past_kv[i][3] if is_legacy_cache else past_kv.cross_attention_cache.layers[i].values
|
||||
)
|
||||
cross_attention_layer_key_cache = cross_attention_layer_key_cache[:, :, 0, :]
|
||||
cross_attention_layer_value_cache = cross_attention_layer_value_cache[:, :, 0, :]
|
||||
self.assertEqual(cross_attention_layer_key_cache.shape, all_cache_shapes[i][2])
|
||||
self.assertEqual(cross_attention_layer_value_cache.shape, all_cache_shapes[i][3])
|
||||
cross_attention_layer_keys = cross_attention_layer_keys[:, :, 0, :]
|
||||
cross_attention_layer_values = cross_attention_layer_values[:, :, 0, :]
|
||||
self.assertEqual(cross_attention_layer_keys.shape, all_cache_shapes[i][2])
|
||||
self.assertEqual(cross_attention_layer_values.shape, all_cache_shapes[i][3])
|
||||
|
||||
# 3.2. Decoder-only checks
|
||||
else:
|
||||
num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv.key_cache)
|
||||
num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv)
|
||||
self.assertEqual(num_cache_decoder_layers, num_decoder_layers)
|
||||
|
||||
for i in range(num_decoder_layers):
|
||||
@ -1110,10 +1108,10 @@ class T5GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
self.assertEqual(len(past_kv[0]), 2) # legacy check: confirm number of elements in tuple
|
||||
|
||||
# Self attention
|
||||
self_attention_layer_key_cache = past_kv[i][0] if is_legacy_cache else past_kv.key_cache[i]
|
||||
self_attention_layer_value_cache = past_kv[i][1] if is_legacy_cache else past_kv.value_cache[i]
|
||||
self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0])
|
||||
self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1])
|
||||
self_attention_layer_keys = past_kv[i][0] if is_legacy_cache else past_kv.layers[i].keys
|
||||
self_attention_layer_values = past_kv[i][1] if is_legacy_cache else past_kv.layers[i].values
|
||||
self.assertEqual(self_attention_layer_keys.shape, all_cache_shapes[i][0])
|
||||
self.assertEqual(self_attention_layer_values.shape, all_cache_shapes[i][1])
|
||||
|
||||
@unittest.skip("Mismatch issue doesn't exist in T5Gemma.")
|
||||
def test_load_with_mismatched_shapes(self):
|
||||
|
@ -171,7 +171,9 @@ class CacheTest(unittest.TestCase):
|
||||
return random_keys, random_values
|
||||
|
||||
mha_config = LlamaConfig(num_attention_heads=32)
|
||||
mha_static_cache = StaticCache(config=mha_config, max_batch_size=1, max_cache_len=10, device=torch_device)
|
||||
mha_static_cache = StaticCache(
|
||||
model_config=mha_config, max_batch_size=1, max_cache_len=10, device=torch_device
|
||||
)
|
||||
cached_keys, cached_values = mha_static_cache.update(
|
||||
*_random_kvs(mha_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)}
|
||||
)
|
||||
@ -179,7 +181,9 @@ class CacheTest(unittest.TestCase):
|
||||
self.assertTrue(cached_values.shape == (1, 32, 10, 128))
|
||||
|
||||
gqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=4)
|
||||
gqa_static_cache = StaticCache(config=gqa_config, max_batch_size=1, max_cache_len=10, device=torch_device)
|
||||
gqa_static_cache = StaticCache(
|
||||
model_config=gqa_config, max_batch_size=1, max_cache_len=10, device=torch_device
|
||||
)
|
||||
cached_keys, cached_values = gqa_static_cache.update(
|
||||
*_random_kvs(gqa_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)}
|
||||
)
|
||||
@ -187,7 +191,9 @@ class CacheTest(unittest.TestCase):
|
||||
self.assertTrue(cached_values.shape == (1, 4, 10, 128))
|
||||
|
||||
mqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=1)
|
||||
mqa_static_cache = StaticCache(config=mqa_config, max_batch_size=1, max_cache_len=10, device=torch_device)
|
||||
mqa_static_cache = StaticCache(
|
||||
model_config=mqa_config, max_batch_size=1, max_cache_len=10, device=torch_device
|
||||
)
|
||||
cached_keys, cached_values = mqa_static_cache.update(
|
||||
*_random_kvs(mqa_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)}
|
||||
)
|
||||
@ -323,7 +329,7 @@ class CacheIntegrationTest(unittest.TestCase):
|
||||
)
|
||||
|
||||
self.assertIsInstance(gen_out.past_key_values, QuantizedCache)
|
||||
processor = gen_out.past_key_values.processor
|
||||
processor = gen_out.past_key_values.cache_processor
|
||||
if backend == "quanto":
|
||||
self.assertIsInstance(processor, QuantoQuantizedCacheProcessor)
|
||||
elif backend == "hqq":
|
||||
@ -332,12 +338,10 @@ class CacheIntegrationTest(unittest.TestCase):
|
||||
decoded = self.tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True)
|
||||
self.assertListEqual(decoded, expected_generation)
|
||||
|
||||
self.assertTrue(len(processor._quantized_key_cache) > 0)
|
||||
self.assertTrue(len(processor._quantized_keys) > 0)
|
||||
|
||||
# Check that something is actually quantized
|
||||
has_been_quantized = any(
|
||||
(q[0] if isinstance(q, tuple) else q).numel() > 0 for q in processor._quantized_key_cache
|
||||
)
|
||||
has_been_quantized = any((q[0] if isinstance(q, tuple) else q).numel() > 0 for q in processor._quantized_keys)
|
||||
self.assertTrue(has_been_quantized)
|
||||
|
||||
@parameterized.expand(TEST_CACHE_IMPLEMENTATIONS)
|
||||
@@ -654,7 +658,7 @@ class CacheExportIntegrationTest(unittest.TestCase):
            past_key_values=DynamicCache(),
            use_cache=True,
        )
-        self.assertTrue(len(res.past_key_values.key_cache) == model.config.num_hidden_layers)
+        self.assertTrue(len(res.past_key_values) == model.config.num_hidden_layers)
        self.assertEqual(2 * model.config.num_hidden_layers + 1, len(ep.graph_signature.output_specs))
        self.assertEqual(
            3,
@@ -675,11 +679,9 @@ class CacheExportIntegrationTest(unittest.TestCase):
            use_cache=True,
        )
        self.assertTrue(torch.allclose(res.logits, res_eager.logits))
-        for k1, k2 in zip(res.past_key_values.key_cache, res_eager.past_key_values.key_cache):
-            self.assertTrue(torch.allclose(k1, k2))
-
-        for v1, v2 in zip(res.past_key_values.value_cache, res_eager.past_key_values.value_cache):
-            self.assertTrue(torch.allclose(v1, v2))
+        for l1, l2 in zip(res.past_key_values.layers, res_eager.past_key_values.layers):
+            self.assertTrue(torch.allclose(l1.keys, l2.keys))
+            self.assertTrue(torch.allclose(l1.values, l2.values))

    def test_dynamic_cache_exportability_multiple_run(self):
        # When exporting with DynamicCache, you should export two graphs:
@@ -703,7 +705,7 @@ class CacheExportIntegrationTest(unittest.TestCase):
            past_key_values=DynamicCache(),
            use_cache=True,
        )
-        self.assertTrue(len(res.past_key_values.key_cache) == model.config.num_hidden_layers)
+        self.assertTrue(len(res.past_key_values) == model.config.num_hidden_layers)
        self.assertEqual(2 * model.config.num_hidden_layers + 1, len(ep.graph_signature.output_specs))
        self.assertEqual(
            3,
@@ -728,9 +730,9 @@ class CacheExportIntegrationTest(unittest.TestCase):
        shapes = torch.export.ShapesCollection()
        dyn = torch.export.Dim("seq", max=512)

-        for ix in range(len(past_key_values.key_cache)):
-            shapes[past_key_values.key_cache[ix]] = (None, None, dyn, None)
-            shapes[past_key_values.value_cache[ix]] = (None, None, dyn, None)
+        for ix in range(len(past_key_values)):
+            shapes[past_key_values.layers[ix].keys] = (None, None, dyn, None)
+            shapes[past_key_values.layers[ix].values] = (None, None, dyn, None)

        ep_second = torch.export.export(
            model,
@@ -771,11 +773,9 @@ class CacheExportIntegrationTest(unittest.TestCase):
            use_cache=True,
        )

-        for k1, k2 in zip(res_export_2.past_key_values.key_cache, res_eager_2.past_key_values.key_cache):
-            self.assertTrue(torch.allclose(k1, k2))
-
-        for v1, v2 in zip(res_export_2.past_key_values.value_cache, res_eager_2.past_key_values.value_cache):
-            self.assertTrue(torch.allclose(v1, v2))
+        for l1, l2 in zip(res_export_2.past_key_values.layers, res_eager_2.past_key_values.layers):
+            self.assertTrue(torch.allclose(l1.keys, l2.keys))
+            self.assertTrue(torch.allclose(l1.values, l2.values))

    def test_static_cache_exportability(self):
        """
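Taken together, the export hunks show the two places where the layered cache surfaces under `torch.export`: registering a dynamic sequence dimension for every layer's `keys`/`values`, and comparing caches layer by layer afterwards. A minimal sketch of both steps, assuming the caller already has the `DynamicCache` objects in hand (the helper names are mine, not the library's):

```py
import torch
from transformers import DynamicCache


def register_dynamic_seq_len(past_key_values: DynamicCache, max_seq_len: int = 512):
    """Mark the sequence dimension of every cached layer as dynamic for torch.export."""
    shapes = torch.export.ShapesCollection()
    dyn = torch.export.Dim("seq", max=max_seq_len)
    for ix in range(len(past_key_values)):  # len(cache) now counts layers
        shapes[past_key_values.layers[ix].keys] = (None, None, dyn, None)
        shapes[past_key_values.layers[ix].values] = (None, None, dyn, None)
    return shapes


def assert_caches_allclose(cache_a: DynamicCache, cache_b: DynamicCache):
    """Compare two caches layer by layer through the layers[idx].keys/values attributes."""
    for l1, l2 in zip(cache_a.layers, cache_b.layers):
        assert torch.allclose(l1.keys, l2.keys)
        assert torch.allclose(l1.values, l2.values)
```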
@@ -922,7 +922,7 @@ class SyntheticCacheTest(unittest.TestCase):

    def test_static_cache_out_of_bounds(self):
        """Test StaticCache raises IndexError for out-of-bounds positions."""
-        static_cache = StaticCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
+        static_cache = StaticCache(model_config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
        pos_out_of_bounds = torch.tensor([self.max_cache_len])  # Position >= max_cache_len

        with self.assertRaises(IndexError):
@@ -944,7 +944,7 @@ class SyntheticCacheTest(unittest.TestCase):
        update pos 3: [1.0, 2.0, 3.0, 4.0]
        """
        # Scenario 1: Fill up to near capacity
-        static_cache = StaticCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
+        static_cache = StaticCache(model_config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
        prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None]
        static_cache.update(key_states=prefill, value_states=prefill, layer_idx=0, cache_kwargs=None)
        static_cache.update(
@@ -954,7 +954,7 @@ class SyntheticCacheTest(unittest.TestCase):
            cache_kwargs={"cache_position": torch.tensor([2])},
        )
        self.assertEqual(
-            static_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "StaticCache Scenario 1 failed"
+            static_cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "StaticCache Scenario 1 failed"
        )

        # Scenario 2: Fill to capacity
@@ -965,7 +965,7 @@ class SyntheticCacheTest(unittest.TestCase):
            cache_kwargs={"cache_position": torch.tensor([3])},
        )
        self.assertEqual(
-            static_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "StaticCache Scenario 2 failed"
+            static_cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "StaticCache Scenario 2 failed"
        )

    def test_sliding_window_cache(self):
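Outside the test harness, the same flow — build the cache from a model config (note the `model_config=` keyword), write states at explicit cache positions, and read them back through `layers[idx].keys` — looks roughly like this. A minimal sketch; the tiny `LlamaConfig` and all dimensions are assumptions chosen only to keep the tensors small:

```py
import torch
from transformers import LlamaConfig, StaticCache

# Assumption: a deliberately tiny config; any decoder config should work the same way.
config = LlamaConfig(hidden_size=8, num_attention_heads=2, num_key_value_heads=2, num_hidden_layers=2)
head_dim = config.hidden_size // config.num_attention_heads

cache = StaticCache(model_config=config, max_batch_size=1, max_cache_len=4)

# Key/value states use the [batch_size, num_heads, seq_len, head_dim] layout.
prefill = torch.ones(1, config.num_key_value_heads, 2, head_dim)
cache.update(prefill, prefill, layer_idx=0, cache_kwargs={"cache_position": torch.arange(2)})

new_token = torch.full((1, config.num_key_value_heads, 1, head_dim), 2.0)
cache.update(new_token, new_token, layer_idx=0, cache_kwargs={"cache_position": torch.tensor([2])})

# The sequence dimension stays fixed at max_cache_len, which is what makes the layer compile-friendly.
print(cache.layers[0].keys.shape)
```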
@@ -984,7 +984,9 @@ class SyntheticCacheTest(unittest.TestCase):
        result: [3.0, 4.0, 5.0, 6.0] (keeps last window_size tokens)
        """
        # Scenario 1: Update within window, no slide yet
-        sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
+        sliding_cache = SlidingWindowCache(
+            model_config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len
+        )
        prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None]
        sliding_cache.update(
            key_states=prefill,
@@ -999,13 +1001,15 @@ class SyntheticCacheTest(unittest.TestCase):
            cache_kwargs={"cache_position": torch.tensor([2]), "sliding_window": self.window_size},
        )
        self.assertEqual(
-            sliding_cache.key_cache[0][0, 0, :, 0].tolist(),
+            sliding_cache.layers[0].keys[0, 0, :, 0].tolist(),
            [1.0, 2.0, 3.0, 0.0],
            "SlidingWindowCache Scenario 1 failed",
        )

        # Scenario 2: Update causing slide
-        sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
+        sliding_cache = SlidingWindowCache(
+            model_config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len
+        )
        prefill = torch.tensor([1.0, 2.0, 3.0, 4.0])[None, None, :, None]
        sliding_cache.update(
            key_states=prefill,
@@ -1020,13 +1024,15 @@ class SyntheticCacheTest(unittest.TestCase):
            cache_kwargs={"cache_position": torch.tensor([4]), "sliding_window": self.window_size},
        )
        self.assertEqual(
-            sliding_cache.key_cache[0][0, 0, :, 0].tolist(),
+            sliding_cache.layers[0].keys[0, 0, :, 0].tolist(),
            [2.0, 3.0, 4.0, 5.0],
            "SlidingWindowCache Scenario 2 failed",
        )

        # Scenario 3: Long prompt handling
-        sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
+        sliding_cache = SlidingWindowCache(
+            model_config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len
+        )
        long_prefill = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])[None, None, :, None]
        sliding_cache.update(
            key_states=long_prefill,
@@ -1035,7 +1041,7 @@ class SyntheticCacheTest(unittest.TestCase):
            cache_kwargs={"cache_position": torch.arange(6), "sliding_window": self.window_size},
        )
        self.assertEqual(
-            sliding_cache.key_cache[0][0, 0, :, 0].tolist(),
+            sliding_cache.layers[0].keys[0, 0, :, 0].tolist(),
            [3.0, 4.0, 5.0, 6.0],
            "SlidingWindowCache Scenario 3 failed",
        )
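For the sliding-window variant, the behaviour asserted above is that once the window is full, adding a token pushes the oldest entry out. A sketch of that eviction under assumptions analogous to the tests (a tiny `MistralConfig` with a single head and `head_dim` of 1; the exact `cache_kwargs` your version expects may differ):

```py
import torch
from transformers import MistralConfig, SlidingWindowCache

# Assumption: a tiny config whose sliding_window matches the cache length used here.
config = MistralConfig(
    hidden_size=1, num_attention_heads=1, num_key_value_heads=1, num_hidden_layers=1, sliding_window=4
)
cache = SlidingWindowCache(model_config=config, max_batch_size=1, max_cache_len=4)

# Fill the window with four positions, then append a fifth token.
prefill = torch.tensor([1.0, 2.0, 3.0, 4.0])[None, None, :, None]
cache.update(prefill, prefill, layer_idx=0, cache_kwargs={"cache_position": torch.arange(4), "sliding_window": 4})

token5 = torch.tensor([5.0])[None, None, :, None]
cache.update(token5, token5, layer_idx=0, cache_kwargs={"cache_position": torch.tensor([4]), "sliding_window": 4})

# The oldest entry (1.0) has been shifted out, mirroring Scenario 2 above.
print(cache.layers[0].keys[0, 0, :, 0].tolist())  # expected: [2.0, 3.0, 4.0, 5.0]
```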
@@ -1054,7 +1060,7 @@ class SyntheticCacheTest(unittest.TestCase):
        config.sliding_window_pattern = 1  # Layer 0 is static (1 % 1 == 0)

        # Scenario 1
-        hybrid_cache_static_mode = HybridCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len)
+        hybrid_cache_static_mode = HybridCache(model_config=config, max_batch_size=1, max_cache_len=self.max_cache_len)
        prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None]
        hybrid_cache_static_mode.update(
            key_states=prefill,
@@ -1069,7 +1075,7 @@ class SyntheticCacheTest(unittest.TestCase):
            cache_kwargs={"cache_position": torch.tensor([2])},
        )
        self.assertEqual(
-            hybrid_cache_static_mode.key_cache[0][0, 0, :, 0].tolist(),
+            hybrid_cache_static_mode.layers[0].keys[0, 0, :, 0].tolist(),
            [1.0, 2.0, 3.0, 0.0],
            "HybridCache Static Scenario 1 failed",
        )
@@ -1082,7 +1088,7 @@ class SyntheticCacheTest(unittest.TestCase):
            cache_kwargs={"cache_position": torch.tensor([3])},
        )
        self.assertEqual(
-            hybrid_cache_static_mode.key_cache[0][0, 0, :, 0].tolist(),
+            hybrid_cache_static_mode.layers[0].keys[0, 0, :, 0].tolist(),
            [1.0, 2.0, 3.0, 4.0],
            "HybridCache Static Scenario 2 failed",
        )
@@ -1106,7 +1112,7 @@ class SyntheticCacheTest(unittest.TestCase):
        result: [3.0, 4.0, 5.0, 6.0] (keeps last window_size tokens)
        """
        # Scenario 1: Update within window, no slide yet
-        hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
+        hybrid_cache = HybridCache(model_config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
        prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None]
        hybrid_cache.update(
            key_states=prefill,
@@ -1121,13 +1127,13 @@ class SyntheticCacheTest(unittest.TestCase):
            cache_kwargs={"cache_position": torch.tensor([2]), "sliding_window": self.window_size},
        )
        self.assertEqual(
-            hybrid_cache.key_cache[0][0, 0, :, 0].tolist(),
+            hybrid_cache.layers[0].keys[0, 0, :, 0].tolist(),
            [1.0, 2.0, 3.0, 0.0],
            "HybridCache Sliding Scenario 1 failed",
        )

        # Scenario 2: Update causing first slide
-        hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
+        hybrid_cache = HybridCache(model_config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
        prefill = torch.tensor([1.0, 2.0, 3.0, 4.0])[None, None, :, None]
        hybrid_cache.update(
            key_states=prefill,
@@ -1142,7 +1148,7 @@ class SyntheticCacheTest(unittest.TestCase):
            cache_kwargs={"cache_position": torch.tensor([4]), "sliding_window": self.window_size},
        )
        self.assertEqual(
-            hybrid_cache.key_cache[0][0, 0, :, 0].tolist(),
+            hybrid_cache.layers[0].keys[0, 0, :, 0].tolist(),
            [2.0, 3.0, 4.0, 5.0],
            "HybridCache Sliding Scenario 2 failed",
        )
@@ -1155,13 +1161,13 @@ class SyntheticCacheTest(unittest.TestCase):
            cache_kwargs={"cache_position": torch.tensor([5]), "sliding_window": self.window_size},
        )
        self.assertEqual(
-            hybrid_cache.key_cache[0][0, 0, :, 0].tolist(),
+            hybrid_cache.layers[0].keys[0, 0, :, 0].tolist(),
            [3.0, 4.0, 5.0, 6.0],
            "HybridCache Sliding Scenario 3 failed",
        )

        # Scenario 4: Long prompt handling
-        hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
+        hybrid_cache = HybridCache(model_config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
        long_prefill = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])[None, None, :, None]
        hybrid_cache.update(
            key_states=long_prefill,
@@ -1170,7 +1176,7 @@ class SyntheticCacheTest(unittest.TestCase):
            cache_kwargs={"cache_position": torch.arange(6), "sliding_window": self.window_size},
        )
        self.assertEqual(
-            hybrid_cache.key_cache[0][0, 0, :, 0].tolist(),
+            hybrid_cache.layers[0].keys[0, 0, :, 0].tolist(),
            [3.0, 4.0, 5.0, 6.0],
            "HybridCache Sliding Scenario 4 failed",
        )
@@ -1190,10 +1196,10 @@ class SyntheticCacheTest(unittest.TestCase):
        cache = DynamicCache()
        cache.update(prefill, prefill, 0)
        cache.update(update3, update3, 0)
-        self.assertEqual(cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0], "DynamicCache Scenario 1 failed")
+        self.assertEqual(cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0], "DynamicCache Scenario 1 failed")
        cache.update(update4, update4, 0)
        self.assertEqual(
-            cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "DynamicCache Scenario 1 (to 4) failed"
+            cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "DynamicCache Scenario 1 (to 4) failed"
        )

        # Scenario 2: prefill and update for two layers independently
@@ -1210,8 +1216,10 @@ class SyntheticCacheTest(unittest.TestCase):
        cache.update(update4, update4, 0)
        cache.update(update4_1, update4_1, 1)
        self.assertEqual(
-            cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "DynamicCache Scenario 2 layer 0 failed"
+            cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "DynamicCache Scenario 2 layer 0 failed"
        )
        self.assertEqual(
-            cache.key_cache[1][0, 0, :, 0].tolist(), [10.0, 20.0, 30.0, 40.0], "DynamicCache Scenario 2 layer 1 failed"
+            cache.layers[1].keys[0, 0, :, 0].tolist(),
+            [10.0, 20.0, 30.0, 40.0],
+            "DynamicCache Scenario 2 layer 1 failed",
        )
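Finally, the `DynamicCache` scenarios boil down to: `update()` appends along the sequence dimension, and each layer is tracked independently. A minimal standalone sketch of that growth (tensor values are arbitrary):

```py
import torch
from transformers import DynamicCache

cache = DynamicCache()

# States follow the usual [batch_size, num_heads, seq_len, head_dim] layout.
prefill = torch.tensor([1.0, 2.0])[None, None, :, None]
cache.update(prefill, prefill, 0)  # layer 0 now holds two positions

token3 = torch.tensor([3.0])[None, None, :, None]
cache.update(token3, token3, 0)  # grows to three positions

print(cache.layers[0].keys[0, 0, :, 0].tolist())  # [1.0, 2.0, 3.0]
```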
@@ -956,8 +956,9 @@ def fix_docstring(obj: Any, old_doc_args: str, new_doc_args: str):
            idx += 1

    if "".join(source[start_idx:idx])[:-1] != old_doc_args:
-        # Args are not fully defined in the docstring of this object
-        return
+        raise ValueError(
+            f"Expected\n{old_doc_args}\nbut got\n{''.join(source[start_idx:idx])[:-1]}\n in {find_source_file(obj)}: {obj.__name__}"
+        )

    obj_file = find_source_file(obj)
    with open(obj_file, "r", encoding="utf-8") as f: