mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 04:40:06 +06:00
[generate] move SinkCache
to a custom_generate
repo (#38399)
remove sink cache
This commit is contained in:
parent
fe5bfaa4b5
commit
beaed8ce01
@ -380,11 +380,6 @@ A [`Constraint`] can be used to force the generation to include specific tokens
|
||||
|
||||
[[autodoc]] HQQQuantizedCache
|
||||
|
||||
[[autodoc]] SinkCache
|
||||
- update
|
||||
- get_seq_length
|
||||
- reorder_cache
|
||||
|
||||
[[autodoc]] OffloadedCache
|
||||
- update
|
||||
- prefetch_layer
|
||||
@ -443,4 +438,3 @@ A [`Constraint`] can be used to force the generation to include specific tokens
|
||||
|
||||
[[autodoc]] CompileConfig
|
||||
- __call__
|
||||
|
||||
|
@ -30,7 +30,6 @@ Transformers offers several [`Cache`] classes that implement different caching m
|
||||
| Offloaded Static Cache | No | Yes | Yes | High | Yes |
|
||||
| Quantized Cache | Yes | No | No | Low | Yes |
|
||||
| Sliding Window Cache | No | Yes | Yes | High | No |
|
||||
| Sink Cache | Yes | No | Yes | Mid | Yes |
|
||||
|
||||
This guide introduces you to the different [`Cache`] classes and shows you how to use them for generation.
|
||||
|
||||
@ -174,28 +173,6 @@ I like rock music because it's loud and energetic. It's a great way to express m
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### Sink cache
|
||||
|
||||
[`SinkCache`] is capable of generating very long sequences ("infinite length" according to the paper) by only retaining a few initial tokens from the sequence. These are called the *sink tokens* because they account for a significant portion of the attention scores during generation. Subsequent tokens are discarded on a sliding windowed basis, and only the latest `window_size` tokens are kept. This means most of the previous knowledge is discarded.
|
||||
|
||||
The sink tokens allow a model to maintain stable performance even when it's dealing with very long text sequences.
|
||||
|
||||
Enable [`SinkCache`] by initializing it first with the [window_length](https://hf.co/docs/transformers/main/en/internal/generation_utils#transformers.SinkCache.window_length) and [num_sink_tokens](https://hf.co/docs/transformers/main/en/internal/generation_utils#transformers.SinkCache.num_sink_tokens) parameters before passing it to [past_key_values](https://hf.co/docs/transformers/internal/generation_utils#transformers.generation.GenerateDecoderOnlyOutput.past_key_values) in [`~GenerationMixin.generate`].
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, SinkCache
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
|
||||
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16).to("cuda:0")
|
||||
inputs = tokenizer("This is a long story about unicorns, fairies and magic.", return_tensors="pt").to(model.device)
|
||||
|
||||
past_key_values = SinkCache(window_length=256, num_sink_tokens=4)
|
||||
out = model.generate(**inputs, do_sample=False, max_new_tokens=30, past_key_values=past_key_values)
|
||||
tokenizer.batch_decode(out, skip_special_tokens=True)[0]
|
||||
"This is a long story about unicorns, fairies and magic. It is a fantasy world where unicorns and fairies live together in harmony. The story follows a young girl named Lily"
|
||||
```
|
||||
|
||||
## Speed optimized caches
|
||||
|
||||
The default [`DynamicCache`] prevents you from taking advantage of just-in-time (JIT) optimizations because the cache size isn't fixed. JIT optimizations enable you to maximize latency at the expense of memory usage. All of the following cache types are compatible with JIT optimizations like [torch.compile](./llm_optims#static-kv-cache-and-torchcompile) to accelerate generation.
|
||||
@ -247,7 +224,7 @@ Enable [`SlidingWindowCache`] by configuring `cache_implementation="sliding_wind
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, SinkCache
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
|
||||
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", torch_dtype=torch.float16).to("cuda:0")
|
||||
@ -284,8 +261,6 @@ A cache can also work in iterative generation settings where there is back-and-f
|
||||
|
||||
For iterative generation with a cache, start by initializing an empty cache class and then you can feed in your new prompts. Keep track of dialogue history with a [chat template](./chat_templating).
|
||||
|
||||
If you're using [`SinkCache`], the inputs need to be truncated to the maximum length because [`SinkCache`] can generate text that exceeds its maximum window size. However, the first input shouldn't exceed the maximum cache length.
|
||||
|
||||
The example below demonstrates how to use a cache for iterative generation.
|
||||
|
||||
```py
|
||||
@ -293,7 +268,6 @@ import torch
|
||||
from transformers import AutoTokenizer,AutoModelForCausalLM
|
||||
from transformers.cache_utils import (
|
||||
DynamicCache,
|
||||
SinkCache,
|
||||
StaticCache,
|
||||
SlidingWindowCache,
|
||||
QuantoQuantizedCache,
|
||||
@ -313,8 +287,6 @@ messages = []
|
||||
for prompt in user_prompts:
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(model.device)
|
||||
if isinstance(past_key_values, SinkCache):
|
||||
inputs = {k: v[:, -max_cache_length:] for k, v in inputs.items()}
|
||||
input_length = inputs["input_ids"].shape[1]
|
||||
outputs = model.generate(**inputs, do_sample=False, max_new_tokens=256, past_key_values=past_key_values)
|
||||
completion = tokenizer.decode(outputs[0, input_length: ], skip_special_tokens=True)
|
||||
@ -336,7 +308,7 @@ model_id = "meta-llama/Llama-2-7b-chat-hf"
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="cuda")
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
|
||||
# Init StaticCache with big enough max-length (1024 tokens for the below example)
|
||||
# Init StaticCache with big enough max-length (1024 tokens for the below example)
|
||||
# You can also init a DynamicCache, if that suits you better
|
||||
prompt_cache = StaticCache(config=model.config, max_batch_size=1, max_cache_len=1024, device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
@ -351,7 +323,7 @@ responses = []
|
||||
for prompt in prompts:
|
||||
new_inputs = tokenizer(INITIAL_PROMPT + prompt, return_tensors="pt").to("cuda")
|
||||
past_key_values = copy.deepcopy(prompt_cache)
|
||||
outputs = model.generate(**new_inputs, past_key_values=past_key_values,max_new_tokens=20)
|
||||
outputs = model.generate(**new_inputs, past_key_values=past_key_values,max_new_tokens=20)
|
||||
response = tokenizer.batch_decode(outputs)[0]
|
||||
responses.append(response)
|
||||
|
||||
|
@ -366,11 +366,6 @@ generation_output[:2]
|
||||
|
||||
[[autodoc]] HQQQuantizedCache
|
||||
|
||||
[[autodoc]] SinkCache
|
||||
- update
|
||||
- get_seq_length
|
||||
- reorder_cache
|
||||
|
||||
[[autodoc]] OffloadedCache
|
||||
- update
|
||||
- prefetch_layer
|
||||
|
@ -2,7 +2,6 @@ import copy
|
||||
import importlib.metadata
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
@ -1063,199 +1062,18 @@ class HQQQuantizedCache(QuantizedCache):
|
||||
|
||||
class SinkCache(Cache):
|
||||
"""
|
||||
Deprecated.
|
||||
|
||||
A cache that as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to
|
||||
generate beyond the length of its context window, without losing fluency in the conversation. As it discards past
|
||||
tokens, the model will lose the ability to generate tokens that depend on the context that was discarded.
|
||||
|
||||
It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
|
||||
`[batch_size, num_heads, seq_len, head_dim]`.
|
||||
|
||||
Parameters:
|
||||
window_length (`int`):
|
||||
The length of the context window.
|
||||
num_sink_tokens (`int`):
|
||||
The number of sink tokens. See the original paper for more information.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, SinkCache
|
||||
|
||||
>>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
|
||||
|
||||
>>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")
|
||||
|
||||
>>> # Prepare a cache class and pass it to model's forward
|
||||
>>> past_key_values = SinkCache(window_length=256, num_sink_tokens=4)
|
||||
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
|
||||
>>> outputs.past_key_values # access cache filled with key/values from generation
|
||||
SinkCache()
|
||||
```
|
||||
Is its now a `custom_generate` repository on the Hub: https://huggingface.co/transformers-community/sink_cache.
|
||||
See [these docs](https://huggingface.co/docs/transformers/generation_strategies#custom-decoding-methods) for
|
||||
general `custom_generate`usage.
|
||||
"""
|
||||
|
||||
def __init__(self, window_length: int, num_sink_tokens: int) -> None:
|
||||
super().__init__()
|
||||
self.key_cache: List[torch.Tensor] = []
|
||||
self.value_cache: List[torch.Tensor] = []
|
||||
self.window_length = window_length
|
||||
self.num_sink_tokens = num_sink_tokens
|
||||
self.cos_sin_rerotation_cache = {}
|
||||
self._cos_cache = None
|
||||
self._sin_cache = None
|
||||
self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
|
||||
|
||||
warnings.warn(
|
||||
"`SinkCache` is deprecated and will be removed in v4.53.0. You can achieve similar functionality by "
|
||||
"using a model with a sliding window attention mechanism, or by expanding RoPE and optionally using an "
|
||||
"offloaded cache implementation.",
|
||||
FutureWarning,
|
||||
# TODO (joao, manuel): Remove this class in v4.59.0
|
||||
def __init__(self, **kwargs) -> None:
|
||||
raise NotImplementedError(
|
||||
"`SinkCache` has been moved as a `custom_generate` repository on the Hub: "
|
||||
"https://huggingface.co/transformers-community/sink_cache. See the repository for usage examples."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _rotate_half(x):
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
def _apply_key_rotary_pos_emb(
|
||||
self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin)
|
||||
return rotated_key_states
|
||||
|
||||
def _get_rerotation_cos_sin(
|
||||
self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if key_states.shape[-2] not in self.cos_sin_rerotation_cache:
|
||||
# Upcast to float32 temporarily for better accuracy
|
||||
cos = cos.to(torch.float32)
|
||||
sin = sin.to(torch.float32)
|
||||
|
||||
# Compute the cos and sin required for back- and forward-rotating to one position earlier in the sequence
|
||||
original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :]
|
||||
shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]]
|
||||
original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :]
|
||||
shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]]
|
||||
rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
|
||||
rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin
|
||||
|
||||
self.cos_sin_rerotation_cache[key_states.shape[-2]] = (
|
||||
rerotation_cos.to(key_states.dtype).unsqueeze(0),
|
||||
rerotation_sin.to(key_states.dtype).unsqueeze(0),
|
||||
)
|
||||
return self.cos_sin_rerotation_cache[key_states.shape[-2]]
|
||||
|
||||
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
|
||||
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
|
||||
# TODO: deprecate this function in favor of `cache_position`
|
||||
# Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
|
||||
if len(self.key_cache) <= layer_idx:
|
||||
return 0
|
||||
return self.key_cache[layer_idx].shape[-2]
|
||||
|
||||
def get_max_cache_shape(self) -> Optional[int]:
|
||||
"""Returns the maximum sequence length of the cache object, in case of SinkCache it is the window length."""
|
||||
return self.window_length
|
||||
|
||||
def update(
|
||||
self,
|
||||
key_states: torch.Tensor,
|
||||
value_states: torch.Tensor,
|
||||
layer_idx: int,
|
||||
cache_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
|
||||
|
||||
Parameters:
|
||||
key_states (`torch.Tensor`):
|
||||
The new key states to cache.
|
||||
value_states (`torch.Tensor`):
|
||||
The new value states to cache.
|
||||
layer_idx (`int`):
|
||||
The index of the layer to cache the states for.
|
||||
cache_kwargs (`Dict[str, Any]`, `optional`):
|
||||
Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`,
|
||||
`cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the
|
||||
rotation as the tokens are shifted.
|
||||
|
||||
Return:
|
||||
A tuple containing the updated key and value states.
|
||||
"""
|
||||
# Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models
|
||||
# with partially rotated position embeddings, like Phi or Persimmon.
|
||||
if cache_kwargs is None:
|
||||
cache_kwargs = {}
|
||||
sin = cache_kwargs.get("sin")
|
||||
cos = cache_kwargs.get("cos")
|
||||
partial_rotation_size = cache_kwargs.get("partial_rotation_size")
|
||||
using_rope = cos is not None and sin is not None
|
||||
|
||||
# Update the number of seen tokens
|
||||
if layer_idx == 0:
|
||||
self._seen_tokens += key_states.shape[-2]
|
||||
|
||||
# Update the sin/cos cache, which holds sin/cos values for all possible positions
|
||||
if using_rope and layer_idx == 0:
|
||||
# BC: some models still pass `sin`/`cos` with 2 dims. In those models, they are the full sin/cos. Remove
|
||||
# after all RoPE models have a llama-like cache utilization.
|
||||
if cos.dim() == 2:
|
||||
self._cos_cache = cos
|
||||
self._sin_cache = sin
|
||||
else:
|
||||
if self._cos_cache is None:
|
||||
self._cos_cache = cos[0, ...]
|
||||
self._sin_cache = sin[0, ...]
|
||||
elif self._cos_cache.shape[0] < self.window_length:
|
||||
self._cos_cache = torch.cat([self._cos_cache, cos[0, ...]], dim=0)
|
||||
self._sin_cache = torch.cat([self._sin_cache, sin[0, ...]], dim=0)
|
||||
|
||||
# [bsz, num_heads, seq_len, head_dim]
|
||||
if len(self.key_cache) <= layer_idx:
|
||||
# Empty cache
|
||||
self.key_cache.append(key_states)
|
||||
self.value_cache.append(value_states)
|
||||
|
||||
elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length:
|
||||
# Growing cache
|
||||
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
|
||||
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
|
||||
|
||||
else:
|
||||
# Shifting cache
|
||||
keys_to_keep = self.key_cache[layer_idx][
|
||||
:, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] :
|
||||
]
|
||||
|
||||
# On RoPE models, we need to recompute the Key rotation as the tokens are shifted
|
||||
if using_rope:
|
||||
rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
|
||||
key_states, self._cos_cache[: self.window_length], self._sin_cache[: self.window_length]
|
||||
)
|
||||
if partial_rotation_size is not None:
|
||||
keys_to_keep, keys_pass = (
|
||||
keys_to_keep[..., :partial_rotation_size],
|
||||
keys_to_keep[..., partial_rotation_size:],
|
||||
)
|
||||
keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin)
|
||||
if partial_rotation_size is not None:
|
||||
keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1)
|
||||
|
||||
# Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens
|
||||
sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens]
|
||||
self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2)
|
||||
|
||||
sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens]
|
||||
values_to_keep = self.value_cache[layer_idx][
|
||||
:, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] :
|
||||
]
|
||||
self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep, value_states], dim=-2)
|
||||
|
||||
return self.key_cache[layer_idx], self.value_cache[layer_idx]
|
||||
|
||||
|
||||
class StaticCache(Cache):
|
||||
"""
|
||||
|
@ -34,7 +34,7 @@ from torch.fx._symbolic_trace import is_fx_tracing
|
||||
from torch.fx.proxy import ParameterProxy
|
||||
|
||||
from .. import logging
|
||||
from ..cache_utils import Cache, DynamicCache, SinkCache, StaticCache
|
||||
from ..cache_utils import Cache, DynamicCache, StaticCache
|
||||
from ..modeling_utils import PretrainedConfig, PreTrainedModel
|
||||
from ..models.auto import get_values
|
||||
from ..models.auto.modeling_auto import (
|
||||
@ -832,12 +832,6 @@ ProxyableDynamicCache = HFProxyableClassMeta(
|
||||
{},
|
||||
proxy_factory_fn=create_cache_proxy_factory_fn(DynamicCache),
|
||||
)
|
||||
ProxyableSinkCache = HFProxyableClassMeta(
|
||||
"ProxyableSinkCache",
|
||||
(SinkCache,),
|
||||
{},
|
||||
proxy_factory_fn=create_cache_proxy_factory_fn(SinkCache),
|
||||
)
|
||||
ProxyableStaticCache = HFProxyableClassMeta(
|
||||
"ProxyableStaticCache",
|
||||
(StaticCache,),
|
||||
@ -880,7 +874,6 @@ class HFTracer(Tracer):
|
||||
_CLASSES_TO_PATCH = {
|
||||
Cache: ProxyableCache,
|
||||
DynamicCache: ProxyableDynamicCache,
|
||||
SinkCache: ProxyableSinkCache,
|
||||
StaticCache: ProxyableStaticCache,
|
||||
}
|
||||
|
||||
|
@ -1050,6 +1050,7 @@ UNDOCUMENTED_OBJECTS = [
|
||||
"VitPoseBackbone", # Internal module
|
||||
"VitPoseBackboneConfig", # Internal module
|
||||
"get_values", # Internal object
|
||||
"SinkCache", # Moved to a custom_generate repository, to be deleted from transformers in v4.59.0
|
||||
]
|
||||
|
||||
# This list should be empty. Objects in it should get their own doc page.
|
||||
|
Loading…
Reference in New Issue
Block a user