Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
joao review: minor things
This commit is contained in: parent 16a6624087, commit aec9ccd6ea
@@ -37,7 +37,7 @@ class CacheProcessor:
cache (`Cache`): The cache instance this processor will be applied to.
**kwargs: Additional arguments that may be needed for initialization.
"""
raise NotImplementedError("Make sure to implement `init` in a subclass.")
raise NotImplementedError(f"Make sure to implement `init` in {self.__class__.__name__}.")

def pre_update(
self,
@@ -55,7 +55,7 @@ class CacheProcessor:
key_states (`torch.Tensor`): The new key states to cache.
value_states (`torch.Tensor`): The new value states to cache.
layer_idx (`int`): The index of the layer to cache the states for.
cache_kwargs (`dict[str, Any]`, `optional`): Additional arguments for the cache.
cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache.

Returns:
tuple[`torch.Tensor`, `torch.Tensor`]: The potentially modified key and value states.
@@ -78,7 +78,7 @@ class CacheProcessor:
key_states (`torch.Tensor`): The key states that were cached.
value_states (`torch.Tensor`): The value states that were cached.
layer_idx (`int`): The index of the layer that was updated.
cache_kwargs (`dict[str, Any]`, `optional`): Additional arguments for the cache.
cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache.

Returns:
tuple[`torch.Tensor`, `torch.Tensor`]: The final key and value states to return.
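The three hunks above touch the `CacheProcessor` hook API (`init`, `pre_update`, `post_update`). As a rough illustration only, here is a minimal no-op processor sketch; the signatures are inferred from the docstrings in these hunks and may not match the final API exactly, and the import path is an assumption about where this branch defines the class.

from transformers.cache_utils import CacheProcessor  # assumption: module path on this branch

class NoOpCacheProcessor(CacheProcessor):
    """Hypothetical sketch: a processor that passes key/value states through unchanged."""

    def init(self, cache, **kwargs):
        # No extra state to set up for this processor.
        pass

    def pre_update(self, cache, key_states, value_states, layer_idx, cache_kwargs=None):
        # Return the (possibly modified) states that will be written to the cache.
        return key_states, value_states

    def post_update(self, cache, key_states, value_states, layer_idx, cache_kwargs=None):
        # Return the final states handed back to the attention layer.
        return key_states, value_states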
@@ -86,24 +86,15 @@ class CacheProcessor:
return key_tensors, value_tensors


class CacheLayer:
class CacheLayerMixin:
"""Base, abstract class for a single layer's cache."""

is_compileable = False

def __init__(
self,
config: Optional["CacheConfig"] = None,
):
self.keys = None
self.values = None

@classmethod
def from_kv(cls, keys: torch.Tensor, values: torch.Tensor) -> None:
cache = cls()
cache.keys = keys
cache.values = values
return cache
def __repr__(self):
key_repr = "None" if self.keys is None else f"t({tuple(self.keys.shape)})"
value_repr = "None" if self.values is None else f"t({tuple(self.values.shape)})"
return f"{self.__class__.__name__}(K={key_repr}, V={value_repr})"

def update(
self,
@@ -111,28 +102,16 @@ class CacheLayer:
value_states: torch.Tensor,
cache_kwargs: Optional[dict[str, Any]] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Updates KV cache, returns updated K and V for this layer."""
raise NotImplementedError("Make sure to implement `update` in a subclass.")

def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
"""Given the sequence length of the new inputs, returns the usable length of the cache.
Early stops since first layer is enough to compute sequence length"""
# Cache without size limit -> all cache is usable
# Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache
# length, we will need to evict part of the cache (and thus not all cache is usable)
max_length = self.get_max_cache_shape()
previous_seq_length, _ = self.get_seq_length(layer_idx)
if max_length != -1 and previous_seq_length + new_seq_length > max_length:
return max_length - new_seq_length
return previous_seq_length
"""Updates KV cache, returns updated keys/values for this layer."""
raise NotImplementedError(f"Make sure to implement `update` in {self.__class__.__name__}.")

def get_max_cache_shape(self) -> int:
"""Returns the maximum sequence length (i.e. max capacity) of this layer's cache."""
raise NotImplementedError("Make sure to implement `get_max_cache_shape` in a subclass.")
raise NotImplementedError(f"Make sure to implement `get_max_cache_shape` in {self.__class__.__name__}.")

def reset(self) -> None:
"""Resets this layer's cache."""
raise NotImplementedError("Make sure to implement `reset` in a subclass.")
raise NotImplementedError(f"Make sure to implement `reset` in {self.__class__.__name__}.")

def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
"""Reorders this layer's cache for beam search."""
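The `get_usable_length` helper removed in this hunk implements a simple overflow rule: with a bounded cache, only the portion that will survive eviction is usable. A standalone sketch of that arithmetic (not part of the commit, numbers chosen for illustration):

def usable_length(previous_seq_length, new_seq_length, max_length):
    # max_length == -1 marks a cache without a size limit -> all of it is usable.
    if max_length != -1 and previous_seq_length + new_seq_length > max_length:
        return max_length - new_seq_length
    return previous_seq_length

assert usable_length(120, 16, max_length=128) == 112  # bounded cache that would overflow
assert usable_length(120, 16, max_length=-1) == 120   # unbounded cache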
@@ -143,13 +122,11 @@ class CacheLayer:
device = self.values.device
self.values = self.values.index_select(0, beam_idx.to(device))

def __repr__(self):
key_repr = "None" if self.keys is None else f"t({tuple(self.keys.shape)})"
value_repr = "None" if self.values is None else f"t({tuple(self.values.shape)})"
return f"{self.__class__.__name__}(K={key_repr}, V={value_repr})"


class CacheBase:

layers = None

def update(
self,
key_states: torch.Tensor,
@@ -180,13 +157,11 @@ class Cache(CacheBase):
- SlidingWindow layers are limited to sliding window size, Static layers use full max_cache_len
"""

layers = []

def __init__(
self,
model_config: Optional[PretrainedConfig] = None,
cache_processor: Optional[CacheProcessor] = None,
layer_classes: Optional[list[type[CacheLayer]]] = None,
layer_classes: Optional[list[type[CacheLayerMixin]]] = None,
*args,
**kwargs,
):
@@ -207,7 +182,7 @@ class Cache(CacheBase):
- `dtype` (`torch.dtype`): Data type for cache tensors
- `layer_device_map` (`dict[int, Union[str, torch.device]]`): Per-layer device mapping
"""
self.layers: list[CacheLayer] = []
self.layers: list[CacheLayerMixin] = []
self.cache_processor = cache_processor

if (
@@ -225,43 +200,6 @@ class Cache(CacheBase):
if self.cache_processor is not None:
self.cache_processor.init(self, **kwargs)

def append_new_layers(self, layer_idx):
"""
Appends layers to the cache until the layer `layer_idx` is reached.
Used in prefill and for skipped layers.
"""
while len(self.layers) <= layer_idx:
self.layers.append(
self.layer_classes[layer_idx % len(self.layer_classes)](self.config.for_layer(layer_idx))
)

def _update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[dict[str, Any]] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

Parameters:
key_states (`torch.Tensor`):
The new key states to cache.
value_states (`torch.Tensor`):
The new value states to cache.
layer_idx (`int`):
The index of the layer to cache the states for.
cache_kwargs (`dict[str, Any]`, `optional`):
Additional arguments for the cache subclass. These are specific to each subclass and allow new types of
cache to be created.

Return:
A tuple containing the updated key and value states.
"""
self.append_new_layers(layer_idx)
return self.layers[layer_idx].update(key_states, value_states, cache_kwargs)

def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]:
"""
Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
@@ -336,6 +274,46 @@ class Cache(CacheBase):
else:
raise AttributeError(f"{self.__class__.__name__} has no attribute {name}")

def __repr__(self):
return f"{self.__class__.__name__}(layers={self.layers})"

def append_new_layers(self, layer_idx):
"""
Appends layers to the cache until the layer `layer_idx` is reached.
Used in prefill and for skipped layers.
"""
while len(self.layers) <= layer_idx:
self.layers.append(
self.layer_classes[layer_idx % len(self.layer_classes)](self.config.for_layer(layer_idx))
)

def _update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[dict[str, Any]] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

Parameters:
key_states (`torch.Tensor`):
The new key states to cache.
value_states (`torch.Tensor`):
The new value states to cache.
layer_idx (`int`):
The index of the layer to cache the states for.
cache_kwargs (`dict[str, Any]`, *optional*):
Additional arguments for the cache subclass. These are specific to each subclass and allow new types of
cache to be created.

Return:
A tuple containing the updated key and value states.
"""
self.append_new_layers(layer_idx)
return self.layers[layer_idx].update(key_states, value_states, cache_kwargs)

def get_seq_length(self, layer_idx: int = 0) -> int:
"""Returns the sequence length of the cache for the given layer. TODO: deprecate in favor of cache_position"""
if layer_idx >= len(self.layers):
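The `_update` path above appends missing layers lazily and then delegates to the layer's own `update`. A hedged usage sketch of the public entry point (shapes chosen for illustration; assumes the branch keeps the usual `DynamicCache` import and `update` signature):

import torch
from transformers import DynamicCache

cache = DynamicCache()
key_states = torch.randn(1, 8, 4, 64)    # [batch_size, num_heads, seq_len, head_dim]
value_states = torch.randn(1, 8, 4, 64)
keys, values = cache.update(key_states, value_states, layer_idx=0)
print(keys.shape)  # torch.Size([1, 8, 4, 64]) for a freshly created cache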
@@ -354,42 +332,25 @@ class Cache(CacheBase):
"""
return self.layers[layer_idx].get_mask_sizes(cache_position)

def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor]]:
"""Converts the `Cache` instance into its equivalent in the legacy cache format. Used for
backward compatibility."""
legacy_cache = ()
for layer in self.layers:
if layer is not None:
legacy_cache += ((layer.keys, layer.values),)
return legacy_cache

@classmethod
def from_legacy_cache(
cls, past_key_values: Optional[tuple[tuple[torch.FloatTensor, torch.FloatTensor]]] = None
) -> "Cache":
"""Converts a cache in the legacy cache format into an equivalent `Cache`. Used for
backward compatibility."""
cache = cls()
if past_key_values is not None:
for layer_idx in range(len(past_key_values)):
key_states, value_states = past_key_values[layer_idx]
cache.update(key_states, value_states, layer_idx)
return cache

def __repr__(self):
return f"{self.__class__.__name__}(layers={self.layers})"


@dataclass
class CacheConfig:
"""
Base class for cache configs
"""
"""Base class for cache configs"""

def __init__(self, num_layers: Optional[int] = None, cache_implementation: Optional[str] = None):
self.num_layers = num_layers
self.cache_implementation = cache_implementation

# Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__iter__
def __iter__(self):
"""allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin"""
for attr, value in copy.deepcopy(self.__dict__).items():
yield attr, value

# Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__repr__
def __repr__(self):
return f"{self.__class__.__name__} {self.to_json_string()}"

@classmethod
def from_model_config(
cls,
@@ -497,16 +458,6 @@ class CacheConfig:
"""
return copy.deepcopy(self.__dict__)

# Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__iter__
def __iter__(self):
"""allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin"""
for attr, value in copy.deepcopy(self.__dict__).items():
yield attr, value

# Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__repr__
def __repr__(self):
return f"{self.__class__.__name__} {self.to_json_string()}"

def to_json_string(self):
"""
Serializes this instance to a JSON formatted string.
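As the `__iter__` docstring above notes, iterating a config yields attribute/value pairs, so `dict(config)` works, and `__repr__` leans on `to_json_string()`. A hedged sketch, assuming `CacheConfig` on this branch keeps the constructor shown earlier in this diff:

from transformers.cache_utils import CacheConfig  # assumption: module path on this branch

config = CacheConfig(num_layers=2, cache_implementation="dynamic")
print(dict(config))  # {'num_layers': 2, 'cache_implementation': 'dynamic'}
print(config)        # class name followed by the JSON-serialized fields, per __repr__ above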
@@ -640,9 +591,7 @@ class QuantizedCacheConfig(CacheConfig):

@dataclass
class StaticCacheConfig(CacheConfig):
"""
Configuration class for static and sliding window cache settings.
"""
"""Configuration class for static and sliding window cache settings."""

batch_size: Optional[int] = None
max_cache_len: Optional[int] = None
@@ -669,9 +618,7 @@ class StaticCacheConfig(CacheConfig):
logger.warning_once("`dtype` not set in cache initialization, using default `float32`")

def for_layer(self, layer_idx: int):
"""
Returns a StaticCacheConfig for a given layer index.
"""
"""Returns a StaticCacheConfig for a given layer index."""
device = self.layer_device_map[layer_idx] if self.layer_device_map is not None else self.device
return StaticCacheConfig(
self.batch_size,
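`for_layer` above resolves a per-layer device before building the per-layer config. A standalone sketch of just that resolution step (illustrative values, not part of the commit):

layer_device_map = {0: "cuda:0", 1: "cuda:1"}   # per-layer mapping, may be None
default_device = "cpu"
layer_idx = 1
device = layer_device_map[layer_idx] if layer_device_map is not None else default_device
assert device == "cuda:1"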
@@ -724,11 +671,19 @@ class StaticCacheConfig(CacheConfig):
)


class DynamicLayer(CacheLayer):
class DynamicLayer(CacheLayerMixin):
"""
A cache layer that grows dynamically as more tokens are generated. This is the default for generative models.
It stores the Key and Value states as tensors with shape `[batch_size, num_heads, seq_len, head_dim]`.
"""
keys, values = None, None

@classmethod
def from_tensors(cls, keys: torch.Tensor, values: torch.Tensor) -> None:
cache = cls()
cache.keys = keys
cache.values = values
return cache

def update(
self,
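The renamed `from_tensors` constructor above builds a layer directly from pre-existing key/value tensors (it replaces `from_kv`). A hedged usage sketch, assuming `DynamicLayer` is importable from `transformers.cache_utils` on this branch:

import torch
from transformers.cache_utils import DynamicLayer  # assumption: module path on this branch

keys = torch.randn(1, 8, 4, 64)     # [batch_size, num_heads, seq_len, head_dim]
values = torch.randn(1, 8, 4, 64)
layer = DynamicLayer.from_tensors(keys, values)
print(layer)  # DynamicLayer(K=t((1, 8, 4, 64)), V=t((1, 8, 4, 64))), per the __repr__ above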
@@ -744,7 +699,7 @@ class DynamicLayer(CacheLayer):
The new key states to cache.
value_states (`torch.Tensor`):
The new value states to cache.
cache_kwargs (`dict[str, Any]`, `optional`):
cache_kwargs (`dict[str, Any]`, *optional*):
Additional arguments for the cache subclass. No additional arguments are used in `DynamicLayer`.

Return:
@@ -782,8 +737,10 @@ class DynamicLayer(CacheLayer):
self.values = self.values.index_select(0, beam_idx.to(self.values.device))

def crop(self, max_length: int) -> None:
"""Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be
negative to remove `max_length` tokens."""
"""
Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be
negative to remove `max_length` tokens.
"""
if max_length < 0:
max_length = self.get_seq_length() - abs(max_length)
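The negative `max_length` convention in `crop` above means "drop this many tokens from the end". A standalone sketch of the arithmetic (illustrative numbers only):

current_seq_len, max_length = 10, -2
if max_length < 0:
    max_length = current_seq_len - abs(max_length)
assert max_length == 8   # keep the first 8 of 10 cached tokens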
@@ -849,9 +806,35 @@ class DynamicCache(Cache):
# compatibility. The name of the argument doesn't matter.
if ddp_cache_data is not None:
for key_states, value_states in ddp_cache_data:
self.layers.append(DynamicLayer.from_kv(key_states, value_states))
self.layers.append(DynamicLayer.from_tensors(key_states, value_states))
super().__init__(*args, **kwargs)

def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor]]:
"""
Converts the `Cache` instance into its equivalent in the legacy cache format. Used for
backward compatibility.
"""
legacy_cache = ()
for layer in self.layers:
if layer is not None:
legacy_cache += ((layer.keys, layer.values),)
return legacy_cache

@classmethod
def from_legacy_cache(
cls, past_key_values: Optional[tuple[tuple[torch.FloatTensor, torch.FloatTensor]]] = None
) -> "Cache":
"""
Converts a cache in the legacy cache format into an equivalent `Cache`. Used for
backward compatibility.
"""
cache = cls()
if past_key_values is not None:
for layer_idx in range(len(past_key_values)):
key_states, value_states = past_key_values[layer_idx]
cache.update(key_states, value_states, layer_idx)
return cache


# Utilities for `DynamicCache` <> torch.export support
def _flatten_dynamic_cache(
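`to_legacy_cache` and `from_legacy_cache`, now hosted on `DynamicCache`, convert between the layered cache and the legacy per-layer `(keys, values)` tuples. A hedged round-trip sketch (shapes chosen for illustration):

import torch
from transformers import DynamicCache

legacy = tuple(
    (torch.randn(1, 8, 4, 64), torch.randn(1, 8, 4, 64))  # (keys, values) per layer
    for _ in range(2)
)
cache = DynamicCache.from_legacy_cache(legacy)
roundtrip = cache.to_legacy_cache()
assert roundtrip[0][0].shape == legacy[0][0].shape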
@@ -935,7 +918,7 @@ class OffloadedCache(DynamicCache):
super().__init__(cache_processor=OffloadedCacheProcessor(), model_config=model_config)


class StaticLayer(CacheLayer):
class StaticLayer(CacheLayerMixin):
is_compileable = True
is_sliding = False

@@ -1304,14 +1287,18 @@ class EncoderDecoderCache(Cache):

# TODO(gante, sanchit-gandhi): move following functionality into `.generate`
def crop(self, maximum_length: int):
"""Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be
negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search."""
"""
Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be
negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search.
"""
self.check_dynamic_cache(self.crop.__name__)
self.self_attention_cache.crop(maximum_length)

def batch_split(self, full_batch_size: int, split_size: int) -> "list[EncoderDecoderCache]":
"""Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
`_split_model_inputs()` in `generation.utils`"""
"""
Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
`_split_model_inputs()` in `generation.utils`
"""
self.check_dynamic_cache(self.batch_split.__name__)
self_attention_cache = self.self_attention_cache.batch_split(full_batch_size, split_size)
cross_attention_cache = self.cross_attention_cache.batch_split(full_batch_size, split_size)
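Per the hunk above, `EncoderDecoderCache.crop` only touches the self-attention cache; cross-attention entries are computed once from the encoder output and do not grow during decoding. A hedged usage sketch, assuming the usual public imports and that the underlying `DynamicCache` exposes `crop` and `get_seq_length` as on this branch:

import torch
from transformers import DynamicCache, EncoderDecoderCache

cache = EncoderDecoderCache(DynamicCache(), DynamicCache())
cache.self_attention_cache.update(torch.randn(1, 8, 10, 64), torch.randn(1, 8, 10, 64), layer_idx=0)
cache.crop(8)   # keeps only the first 8 cached tokens of the self-attention cache
print(cache.self_attention_cache.get_seq_length())  # 8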
@@ -350,10 +350,6 @@ class PhimoeFlashAttention2(PhimoeAttention):
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
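This hunk drops the `kv_seq_len` bookkeeping that relied on the removed `get_usable_length` helper; the surrounding reshape is unchanged. A standalone shape sketch of that reshape (illustrative sizes, not the model's real dimensions):

import torch

bsz, q_len, num_key_value_heads, head_dim = 1, 4, 8, 64
key_states = torch.randn(bsz, q_len, num_key_value_heads * head_dim)
key_states = key_states.view(bsz, q_len, num_key_value_heads, head_dim).transpose(1, 2)
assert key_states.shape == (bsz, num_key_value_heads, q_len, head_dim)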
@@ -143,6 +143,9 @@ class ZambaHybridDynamicCache(Cache):
self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]
self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]

def __len__(self):
return len(self.key_cache)

# Copied from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache.update
def update(
self,

@@ -147,6 +147,9 @@ class Zamba2HybridDynamicCache(Cache):
self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]
self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]

def __len__(self):
return len(self.key_cache)

def update(
self,
key_states: torch.Tensor,
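Both hybrid caches gain a `__len__` that reports the number of layers tracked in `key_cache`, which keeps `len(past_kv)` working in the generation tests below. A minimal behavioural sketch (hypothetical class, not the real cache):

class _TinyHybridCache:
    def __init__(self, num_hidden_layers):
        self.key_cache = [None] * num_hidden_layers

    def __len__(self):
        # One entry per decoder layer, mirroring the __len__ added above.
        return len(self.key_cache)

assert len(_TinyHybridCache(3)) == 3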
@@ -1626,7 +1626,7 @@ class GenerationTesterMixin:

# 3.2. Decoder-only checks
else:
num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv)
num_cache_decoder_layers = len(past_kv)
self.assertEqual(num_cache_decoder_layers, num_decoder_layers)

for i in range(num_decoder_layers):
@@ -1634,8 +1634,16 @@ class GenerationTesterMixin:
self.assertEqual(len(past_kv[0]), 2) # legacy check: confirm number of elements in tuple

# Self attention
self_attention_layer_keys = past_kv[i][0] if is_legacy_cache else past_kv.layers[i].keys
self_attention_layer_values = past_kv[i][1] if is_legacy_cache else past_kv.layers[i].values
if is_legacy_cache:
self_attention_layer_keys = past_kv[i][0]
self_attention_layer_values = past_kv[i][1]
elif past_kv.layers is None:
# Cache is not layered (i.e., Mamba derivatives)
self_attention_layer_keys = past_kv.key_cache[i]
self_attention_layer_values = past_kv.value_cache[i]
else:
self_attention_layer_keys = past_kv.layers[i].keys
self_attention_layer_values = past_kv.layers[i].values
self.assertEqual(self_attention_layer_keys.shape, all_cache_shapes[i][0])
self.assertEqual(self_attention_layer_values.shape, all_cache_shapes[i][1])
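The new branching distinguishes three cache flavours when reading back per-layer tensors. A hedged helper sketch mirroring that logic (hypothetical function, for illustration only):

def layer_self_attention_kv(past_kv, i, is_legacy_cache):
    if is_legacy_cache:
        return past_kv[i][0], past_kv[i][1]                    # legacy tuple of tuples
    if past_kv.layers is None:                                 # non-layered cache (Mamba derivatives)
        return past_kv.key_cache[i], past_kv.value_cache[i]
    return past_kv.layers[i].keys, past_kv.layers[i].values    # layered Cache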