joao review: minor things

Manuel de Prada Corral 2025-07-02 20:52:22 +02:00
parent 16a6624087
commit aec9ccd6ea
5 changed files with 137 additions and 140 deletions

View File

@ -37,7 +37,7 @@ class CacheProcessor:
cache (`Cache`): The cache instance this processor will be applied to.
**kwargs: Additional arguments that may be needed for initialization.
"""
raise NotImplementedError("Make sure to implement `init` in a subclass.")
raise NotImplementedError(f"Make sure to implement `init` in {self.__class__.__name__}.")
def pre_update(
self,
@ -55,7 +55,7 @@ class CacheProcessor:
key_states (`torch.Tensor`): The new key states to cache.
value_states (`torch.Tensor`): The new value states to cache.
layer_idx (`int`): The index of the layer to cache the states for.
cache_kwargs (`dict[str, Any]`, `optional`): Additional arguments for the cache.
cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache.
Returns:
tuple[`torch.Tensor`, `torch.Tensor`]: The potentially modified key and value states.
@ -78,7 +78,7 @@ class CacheProcessor:
key_states (`torch.Tensor`): The key states that were cached.
value_states (`torch.Tensor`): The value states that were cached.
layer_idx (`int`): The index of the layer that was updated.
cache_kwargs (`dict[str, Any]`, `optional`): Additional arguments for the cache.
cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache.
Returns:
tuple[`torch.Tensor`, `torch.Tensor`]: The final key and value states to return.
@ -86,24 +86,15 @@ class CacheProcessor:
return key_tensors, value_tensors
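For orientation, here is a minimal pass-through processor built against the three hooks documented above. The parameter lists are inferred from the docstrings visible in this hunk (and may be abridged), and the import path is an assumption since the file name is not shown on this page.

from transformers.cache_utils import CacheProcessor  # assumed import path for the class patched above


class PassThroughCacheProcessor(CacheProcessor):
    """No-op processor: satisfies the hook contract but leaves all states untouched."""

    def init(self, cache, **kwargs):
        # Nothing to allocate for a pass-through processor.
        pass

    def pre_update(self, key_states, value_states, layer_idx, cache_kwargs=None):
        # Runs before the layer update; a real processor could transform the incoming states here.
        return key_states, value_states

    def post_update(self, key_states, value_states, layer_idx, cache_kwargs=None):
        # Runs after the layer update; a real processor could adjust the cached states before returning them.
        return key_states, value_states

A processor like this is the natural extension point for behaviours such as offloading, which the file wires in further down via `OffloadedCacheProcessor`.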
class CacheLayer:
class CacheLayerMixin:
"""Base, abstract class for a single layer's cache."""
is_compileable = False
def __init__(
self,
config: Optional["CacheConfig"] = None,
):
self.keys = None
self.values = None
@classmethod
def from_kv(cls, keys: torch.Tensor, values: torch.Tensor) -> None:
cache = cls()
cache.keys = keys
cache.values = values
return cache
def __repr__(self):
key_repr = "None" if self.keys is None else f"t({tuple(self.keys.shape)})"
value_repr = "None" if self.values is None else f"t({tuple(self.values.shape)})"
return f"{self.__class__.__name__}(K={key_repr}, V={value_repr})"
def update(
self,
@ -111,28 +102,16 @@ class CacheLayer:
value_states: torch.Tensor,
cache_kwargs: Optional[dict[str, Any]] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Updates KV cache, returns updated K and V for this layer."""
raise NotImplementedError("Make sure to implement `update` in a subclass.")
def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
"""Given the sequence length of the new inputs, returns the usable length of the cache.
Early stops since first layer is enough to compute sequence length"""
# Cache without size limit -> all cache is usable
# Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache
# length, we will need to evict part of the cache (and thus not all cache is usable)
max_length = self.get_max_cache_shape()
previous_seq_length, _ = self.get_seq_length(layer_idx)
if max_length != -1 and previous_seq_length + new_seq_length > max_length:
return max_length - new_seq_length
return previous_seq_length
"""Updates KV cache, returns updated keys/values for this layer."""
raise NotImplementedError(f"Make sure to implement `update` in {self.__class__.__name__}.")
def get_max_cache_shape(self) -> int:
"""Returns the maximum sequence length (i.e. max capacity) of this layer's cache."""
raise NotImplementedError("Make sure to implement `get_max_cache_shape` in a subclass.")
raise NotImplementedError(f"Make sure to implement `get_max_cache_shape` in {self.__class__.__name__}.")
def reset(self) -> None:
"""Resets this layer's cache."""
raise NotImplementedError("Make sure to implement `reset` in a subclass.")
raise NotImplementedError(f"Make sure to implement `reset` in {self.__class__.__name__}.")
def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
"""Reorders this layer's cache for beam search."""
@ -143,13 +122,11 @@ class CacheLayer:
device = self.values.device
self.values = self.values.index_select(0, beam_idx.to(device))
def __repr__(self):
key_repr = "None" if self.keys is None else f"t({tuple(self.keys.shape)})"
value_repr = "None" if self.values is None else f"t({tuple(self.values.shape)})"
return f"{self.__class__.__name__}(K={key_repr}, V={value_repr})"
class CacheBase:
layers = None
def update(
self,
key_states: torch.Tensor,
@ -180,13 +157,11 @@ class Cache(CacheBase):
- SlidingWindow layers are limited to sliding window size, Static layers use full max_cache_len
"""
layers = []
def __init__(
self,
model_config: Optional[PretrainedConfig] = None,
cache_processor: Optional[CacheProcessor] = None,
layer_classes: Optional[list[type[CacheLayer]]] = None,
layer_classes: Optional[list[type[CacheLayerMixin]]] = None,
*args,
**kwargs,
):
@ -207,7 +182,7 @@ class Cache(CacheBase):
- `dtype` (`torch.dtype`): Data type for cache tensors
- `layer_device_map` (`dict[int, Union[str, torch.device]]`): Per-layer device mapping
"""
self.layers: list[CacheLayer] = []
self.layers: list[CacheLayerMixin] = []
self.cache_processor = cache_processor
if (
@ -225,43 +200,6 @@ class Cache(CacheBase):
if self.cache_processor is not None:
self.cache_processor.init(self, **kwargs)
def append_new_layers(self, layer_idx):
"""
Appends layers to the cache until the layer `layer_idx` is reached.
Used in prefill and for skipped layers.
"""
while len(self.layers) <= layer_idx:
self.layers.append(
self.layer_classes[layer_idx % len(self.layer_classes)](self.config.for_layer(layer_idx))
)
def _update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[dict[str, Any]] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
Parameters:
key_states (`torch.Tensor`):
The new key states to cache.
value_states (`torch.Tensor`):
The new value states to cache.
layer_idx (`int`):
The index of the layer to cache the states for.
cache_kwargs (`dict[str, Any]`, `optional`):
Additional arguments for the cache subclass. These are specific to each subclass and allow new types of
cache to be created.
Return:
A tuple containing the updated key and value states.
"""
self.append_new_layers(layer_idx)
return self.layers[layer_idx].update(key_states, value_states, cache_kwargs)
def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]:
"""
Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
@ -336,6 +274,46 @@ class Cache(CacheBase):
else:
raise AttributeError(f"{self.__class__.__name__} has no attribute {name}")
def __repr__(self):
return f"{self.__class__.__name__}(layers={self.layers})"
def append_new_layers(self, layer_idx):
"""
Appends layers to the cache until the layer `layer_idx` is reached.
Used in prefill and for skipped layers.
"""
while len(self.layers) <= layer_idx:
self.layers.append(
self.layer_classes[layer_idx % len(self.layer_classes)](self.config.for_layer(layer_idx))
)
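To make the modulo indexing above concrete: when fewer layer classes than layers are supplied, the classes cycle across layer indices. A small standalone sketch (the class names are placeholders, not the ones used in this file):

# Placeholder layer classes standing in for, e.g., sliding-window vs. full-length layers.
class LayerA: ...
class LayerB: ...

layer_classes = [LayerA, LayerB]

# append_new_layers instantiates layer_classes[layer_idx % len(layer_classes)] for each new layer,
# so a two-entry list alternates A, B, A, B, ... across the model's layers.
picked = [layer_classes[i % len(layer_classes)].__name__ for i in range(5)]
print(picked)  # ['LayerA', 'LayerB', 'LayerA', 'LayerB', 'LayerA']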
def _update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[dict[str, Any]] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
Parameters:
key_states (`torch.Tensor`):
The new key states to cache.
value_states (`torch.Tensor`):
The new value states to cache.
layer_idx (`int`):
The index of the layer to cache the states for.
cache_kwargs (`dict[str, Any]`, *optional*):
Additional arguments for the cache subclass. These are specific to each subclass and allow new types of
cache to be created.
Return:
A tuple containing the updated key and value states.
"""
self.append_new_layers(layer_idx)
return self.layers[layer_idx].update(key_states, value_states, cache_kwargs)
def get_seq_length(self, layer_idx: int = 0) -> int:
"""Returns the sequence length of the cache for the given layer. TODO: deprecate in favor of cache_position"""
if layer_idx >= len(self.layers):
@ -354,42 +332,25 @@ class Cache(CacheBase):
"""
return self.layers[layer_idx].get_mask_sizes(cache_position)
def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor]]:
"""Converts the `Cache` instance into the its equivalent in the legacy cache format. Used for
backward compatibility."""
legacy_cache = ()
for layer in self.layers:
if layer is not None:
legacy_cache += ((layer.keys, layer.values),)
return legacy_cache
@classmethod
def from_legacy_cache(
cls, past_key_values: Optional[tuple[tuple[torch.FloatTensor, torch.FloatTensor]]] = None
) -> "Cache":
"""Converts a cache in the legacy cache format into an equivalent `Cache`. Used for
backward compatibility."""
cache = cls()
if past_key_values is not None:
for layer_idx in range(len(past_key_values)):
key_states, value_states = past_key_values[layer_idx]
cache.update(key_states, value_states, layer_idx)
return cache
def __repr__(self):
return f"{self.__class__.__name__}(layers={self.layers})"
@dataclass
class CacheConfig:
"""
Base class for cache configs
"""
"""Base class for cache configs"""
def __init__(self, num_layers: Optional[int] = None, cache_implementation: Optional[str] = None):
self.num_layers = num_layers
self.cache_implementation = cache_implementation
# Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__iter__
def __iter__(self):
"""allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin"""
for attr, value in copy.deepcopy(self.__dict__).items():
yield attr, value
# Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__repr__
def __repr__(self):
return f"{self.__class__.__name__} {self.to_json_string()}"
@classmethod
def from_model_config(
cls,
@ -497,16 +458,6 @@ class CacheConfig:
"""
return copy.deepcopy(self.__dict__)
# Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__iter__
def __iter__(self):
"""allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin"""
for attr, value in copy.deepcopy(self.__dict__).items():
yield attr, value
# Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__repr__
def __repr__(self):
return f"{self.__class__.__name__} {self.to_json_string()}"
def to_json_string(self):
"""
Serializes this instance to a JSON formatted string.
@ -640,9 +591,7 @@ class QuantizedCacheConfig(CacheConfig):
@dataclass
class StaticCacheConfig(CacheConfig):
"""
Configuration class for static and sliding window cache settings.
"""
"""Configuration class for static and sliding window cache settings."""
batch_size: Optional[int] = None
max_cache_len: Optional[int] = None
@ -669,9 +618,7 @@ class StaticCacheConfig(CacheConfig):
logger.warning_once("`dtype` not set in cache initialization, using default `float32`")
def for_layer(self, layer_idx: int):
"""
Returns a StaticCacheConfig for a given layer index.
"""
"""Returns a StaticCacheConfig for a given layer index."""
device = self.layer_device_map[layer_idx] if self.layer_device_map is not None else self.device
return StaticCacheConfig(
self.batch_size,
@ -724,11 +671,19 @@ class StaticCacheConfig(CacheConfig):
)
class DynamicLayer(CacheLayer):
class DynamicLayer(CacheLayerMixin):
"""
A cache layer that grows dynamically as more tokens are generated. This is the default for generative models.
It stores the Key and Value states as tensors with shape `[batch_size, num_heads, seq_len, head_dim]`.
"""
keys, values = None, None
@classmethod
def from_tensors(cls, keys: torch.Tensor, values: torch.Tensor) -> "DynamicLayer":
cache = cls()
cache.keys = keys
cache.values = values
return cache
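A short usage sketch of the renamed `from_tensors` constructor, assuming the class is importable from `transformers.cache_utils`:

import torch

from transformers.cache_utils import DynamicLayer  # assumed import path

# Key/value tensors shaped [batch_size, num_heads, seq_len, head_dim].
keys = torch.zeros(1, 8, 16, 64)
values = torch.zeros(1, 8, 16, 64)

layer = DynamicLayer.from_tensors(keys, values)
print(layer)  # DynamicLayer(K=t((1, 8, 16, 64)), V=t((1, 8, 16, 64)))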
def update(
self,
@ -744,7 +699,7 @@ class DynamicLayer(CacheLayer):
The new key states to cache.
value_states (`torch.Tensor`):
The new value states to cache.
cache_kwargs (`dict[str, Any]`, `optional`):
cache_kwargs (`dict[str, Any]`, *optional*):
Additional arguments for the cache subclass. No additional arguments are used in `DynamicLayer`.
Return:
@ -782,8 +737,10 @@ class DynamicLayer(CacheLayer):
self.values = self.values.index_select(0, beam_idx.to(self.values.device))
def crop(self, max_length: int) -> None:
"""Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be
negative to remove `max_length` tokens."""
"""
Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be
negative to remove `max_length` tokens.
"""
if max_length < 0:
max_length = self.get_seq_length() - abs(max_length)
@ -849,9 +806,35 @@ class DynamicCache(Cache):
# compatibility. The name of the argument doesn't matter.
if ddp_cache_data is not None:
for key_states, value_states in ddp_cache_data:
self.layers.append(DynamicLayer.from_kv(key_states, value_states))
self.layers.append(DynamicLayer.from_tensors(key_states, value_states))
super().__init__(*args, **kwargs)
def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor]]:
"""
Converts the `Cache` instance into its equivalent in the legacy cache format. Used for
backward compatibility.
"""
legacy_cache = ()
for layer in self.layers:
if layer is not None:
legacy_cache += ((layer.keys, layer.values),)
return legacy_cache
@classmethod
def from_legacy_cache(
cls, past_key_values: Optional[tuple[tuple[torch.FloatTensor, torch.FloatTensor]]] = None
) -> "Cache":
"""
Converts a cache in the legacy cache format into an equivalent `Cache`. Used for
backward compatibility.
"""
cache = cls()
if past_key_values is not None:
for layer_idx in range(len(past_key_values)):
key_states, value_states = past_key_values[layer_idx]
cache.update(key_states, value_states, layer_idx)
return cache
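A round-trip sketch of the two helpers, which now live on `DynamicCache` rather than on the base `Cache`; `DynamicCache`, `from_legacy_cache` and `to_legacy_cache` are existing public `transformers` APIs:

import torch
from transformers import DynamicCache

# One (key, value) pair per layer, in the legacy tuple-of-tuples format.
legacy = tuple((torch.zeros(1, 8, 4, 64), torch.zeros(1, 8, 4, 64)) for _ in range(2))

cache = DynamicCache.from_legacy_cache(legacy)
roundtripped = cache.to_legacy_cache()
assert roundtripped[0][0].shape == legacy[0][0].shape  # shapes survive the round trip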
# Utilities for `DynamicCache` <> torch.export support
def _flatten_dynamic_cache(
@ -935,7 +918,7 @@ class OffloadedCache(DynamicCache):
super().__init__(cache_processor=OffloadedCacheProcessor(), model_config=model_config)
class StaticLayer(CacheLayer):
class StaticLayer(CacheLayerMixin):
is_compileable = True
is_sliding = False
@ -1304,14 +1287,18 @@ class EncoderDecoderCache(Cache):
# TODO(gante, sanchit-gandhi): move following functionality into `.generate`
def crop(self, maximum_length: int):
"""Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be
negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search."""
"""
Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be
negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search.
"""
self.check_dynamic_cache(self.crop.__name__)
self.self_attention_cache.crop(maximum_length)
def batch_split(self, full_batch_size: int, split_size: int) -> "list[EncoderDecoderCache]":
"""Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
`_split_model_inputs()` in `generation.utils`"""
"""
Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
`_split_model_inputs()` in `generation.utils`
"""
self.check_dynamic_cache(self.batch_split.__name__)
self_attention_cache = self.self_attention_cache.batch_split(full_batch_size, split_size)
cross_attention_cache = self.cross_attention_cache.batch_split(full_batch_size, split_size)

View File

@ -350,10 +350,6 @@ class PhimoeFlashAttention2(PhimoeAttention):
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

View File

@ -143,6 +143,9 @@ class ZambaHybridDynamicCache(Cache):
self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]
self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]
def __len__(self):
return len(self.key_cache)
# Copied from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache.update
def update(
self,

View File

@ -147,6 +147,9 @@ class Zamba2HybridDynamicCache(Cache):
self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]
self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]
def __len__(self):
return len(self.key_cache)
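With `__len__` defined, `len(past_kv)` now works uniformly for these hybrid caches, which is what the simplified decoder-layer count check in the test file below relies on. A minimal usage sketch, assuming the constructor takes a model config, batch size and device (the full signature is not visible in this hunk) and using an example checkpoint name:

from transformers import AutoConfig
from transformers.models.zamba2.modeling_zamba2 import Zamba2HybridDynamicCache  # assumed import path

config = AutoConfig.from_pretrained("Zyphra/Zamba2-1.2B")  # example checkpoint; any Zamba2 config works
cache = Zamba2HybridDynamicCache(config, batch_size=1, device="cpu")  # assumed constructor signature
assert len(cache) == config.num_hidden_layers  # key_cache holds one entry per hidden layer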
def update(
self,
key_states: torch.Tensor,

View File

@ -1626,7 +1626,7 @@ class GenerationTesterMixin:
# 3.2. Decoder-only checks
else:
num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv)
num_cache_decoder_layers = len(past_kv)
self.assertEqual(num_cache_decoder_layers, num_decoder_layers)
for i in range(num_decoder_layers):
@ -1634,8 +1634,16 @@ class GenerationTesterMixin:
self.assertEqual(len(past_kv[0]), 2) # legacy check: confirm number of elements in tuple
# Self attention
self_attention_layer_keys = past_kv[i][0] if is_legacy_cache else past_kv.layers[i].keys
self_attention_layer_values = past_kv[i][1] if is_legacy_cache else past_kv.layers[i].values
if is_legacy_cache:
self_attention_layer_keys = past_kv[i][0]
self_attention_layer_values = past_kv[i][1]
elif past_kv.layers is None:
# Cache is not layered (i.e., Mamba derivatives)
self_attention_layer_keys = past_kv.key_cache[i]
self_attention_layer_values = past_kv.value_cache[i]
else:
self_attention_layer_keys = past_kv.layers[i].keys
self_attention_layer_values = past_kv.layers[i].values
self.assertEqual(self_attention_layer_keys.shape, all_cache_shapes[i][0])
self.assertEqual(self_attention_layer_values.shape, all_cache_shapes[i][1])