Offloaded KV Cache (#31325)
* Initial implementation of OffloadedCache
* enable usage via cache_implementation
* Address feedback, add tests, remove legacy methods.
* Remove flash-attn, discover synchronization bugs, fix bugs
* Prevent usage in CPU only mode
* Add a section about offloaded KV cache to the docs
* Fix typos in docs
* Clarifications and better explanation of streams
This commit is contained in:
parent
b4727a1216
commit
ca59d6f77c
@@ -211,6 +211,80 @@ I like rock music because it's loud and energetic. It's a great way to express m
I like rock music because it's loud and energetic. I like to listen to it when I'm feeling
```

## KV Cache Offloading

Similar to KV cache quantization, this strategy aims to reduce GPU VRAM usage.
It does so by moving the KV cache for most layers to the CPU.
As the model's `forward()` method iterates over the layers, this strategy keeps the current layer's cache on the GPU.
At the same time, it asynchronously prefetches the next layer's cache and sends the previous layer's cache back to the CPU.
Unlike KV cache quantization, this strategy always produces the same result as the default KV cache implementation.
Thus, it can serve as a drop-in replacement or a fallback for it.

Depending on your model and the characteristics of your generation task (size of context, number of generated tokens, number of beams, etc.),
you may notice a small degradation in generation throughput compared to the default KV cache implementation.
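
Conceptually, the offloading overlaps data movement with computation by running host-to-device copies on a second CUDA stream. The snippet below is only a minimal sketch of that two-stream pattern under simplifying assumptions (the `layers` sequence, `forward_with_offloaded_cache`, and the plain `key_cache`/`value_cache` lists are illustrative placeholders, not Transformers internals); the actual implementation added by this commit is the `OffloadedCache` class in the `cache_utils.py` hunk further down.

```python
import torch


def forward_with_offloaded_cache(layers, hidden_states, key_cache, value_cache):
    """Illustrative only: overlap per-layer KV cache transfers with computation.

    Assumes a CUDA device and that layer 0's cache already lives on the GPU.
    `layers` is any sequence of callables taking (hidden_states, key, value).
    """
    prefetch_stream = torch.cuda.Stream()  # side stream for host-to-device copies
    num_layers = len(layers)
    for k, layer in enumerate(layers):
        # Make sure the prefetch of layer k (issued on the previous iteration) is done.
        torch.cuda.current_stream().wait_stream(prefetch_stream)

        # Start prefetching layer k+1 on the side stream while layer k computes.
        with torch.cuda.stream(prefetch_stream):
            nxt = (k + 1) % num_layers
            key_cache[nxt] = key_cache[nxt].to("cuda", non_blocking=True)
            value_cache[nxt] = value_cache[nxt].to("cuda", non_blocking=True)

        # Run layer k on the default stream with its (now on-GPU) cache.
        hidden_states = layer(hidden_states, key_cache[k], value_cache[k])

        # Send layer k-1 back to the CPU on the default stream, so the copy is
        # ordered after every kernel that used those tensors.
        prev = (k - 1) % num_layers
        key_cache[prev] = key_cache[prev].to("cpu", non_blocking=True)
        value_cache[prev] = value_cache[prev].to("cpu", non_blocking=True)
    return hidden_states
```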

To enable KV cache offloading, pass `cache_implementation="offloaded"` in the `generation_config`.

```python
>>> import torch
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
>>> ckpt = "microsoft/Phi-3-mini-4k-instruct"

>>> tokenizer = AutoTokenizer.from_pretrained(ckpt)
>>> model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16).to("cuda:0")
>>> inputs = tokenizer("Fun fact: The shortest", return_tensors="pt").to(model.device)

>>> out = model.generate(**inputs, do_sample=False, max_new_tokens=23, cache_implementation="offloaded")
>>> print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
Fun fact: The shortest war in history was between Britain and Zanzibar on August 27, 1896.

>>> out = model.generate(**inputs, do_sample=False, max_new_tokens=23)
>>> print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
Fun fact: The shortest war in history was between Britain and Zanzibar on August 27, 1896.
```

<Tip warning={true}>

Cache offloading requires a GPU and can be slower than the default KV cache. Use it if you are getting CUDA out of memory errors.

</Tip>

The example below shows how KV cache offloading can be used as a fallback strategy.

```python
>>> import torch
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
>>> def resilient_generate(model, *args, **kwargs):
...     oom = False
...     try:
...         return model.generate(*args, **kwargs)
...     except torch.cuda.OutOfMemoryError as e:
...         print(e)
...         print("retrying with cache_implementation='offloaded'")
...         oom = True
...     if oom:
...         torch.cuda.empty_cache()
...         kwargs["cache_implementation"] = "offloaded"
...         return model.generate(*args, **kwargs)
...
>>> ckpt = "microsoft/Phi-3-mini-4k-instruct"
>>> tokenizer = AutoTokenizer.from_pretrained(ckpt)
>>> model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16).to("cuda:0")
>>> prompt = ["okay "*1000 + "Fun fact: The most"]
>>> inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
>>> beams = {"num_beams": 40, "num_beam_groups": 40, "num_return_sequences": 40, "diversity_penalty": 1.0, "max_new_tokens": 23, "early_stopping": True}
>>> out = resilient_generate(model, **inputs, **beams)
>>> responses = tokenizer.batch_decode(out[:, -28:], skip_special_tokens=True)
```

On a GPU with 50 GB of VRAM, running this code will print
```
CUDA out of memory. Tried to allocate 4.83 GiB. GPU
retrying with cache_implementation='offloaded'
```
before successfully generating 40 beams.
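
To check the memory savings on your own hardware, you can compare peak allocated GPU memory with and without offloading, along the lines of the `test_offloaded_cache_uses_less_memory_than_dynamic_cache` integration test added by this commit. This is a minimal sketch; the difference will be small for a short prompt like the one below and grows with context length and beam count.

```python
>>> import torch
>>> from transformers import AutoTokenizer, AutoModelForCausalLM

>>> ckpt = "microsoft/Phi-3-mini-4k-instruct"
>>> tokenizer = AutoTokenizer.from_pretrained(ckpt)
>>> model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16).to("cuda:0")
>>> inputs = tokenizer("Fun fact:", return_tensors="pt").to(model.device)

>>> torch.cuda.reset_peak_memory_stats(model.device)
>>> _ = model.generate(**inputs, do_sample=False, max_new_tokens=20)
>>> default_peak = torch.cuda.max_memory_allocated(model.device)

>>> torch.cuda.reset_peak_memory_stats(model.device)
>>> _ = model.generate(**inputs, do_sample=False, max_new_tokens=20, cache_implementation="offloaded")
>>> offloaded_peak = torch.cuda.max_memory_allocated(model.device)

>>> print(f"default: {default_peak / 2**20:.0f} MiB, offloaded: {offloaded_peak / 2**20:.0f} MiB")
```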

## Watermarking

The `generate()` supports watermarking the generated text by randomly marking a portion of tokens as "green".

@@ -450,6 +450,118 @@ class DynamicCache(Cache):
            self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]


class OffloadedCache(DynamicCache):
    """
    A drop-in replacement for DynamicCache that conserves GPU memory at the expense of more CPU memory.
    Useful for generating from models with very long context.

    In addition to the default CUDA stream, where all forward() computations happen,
    this class uses another stream, the prefetch stream, which it creates itself.
    Since scheduling of operations on separate streams happens independently, this class uses
    the prefetch stream to asynchronously prefetch the KV cache of layer k+1 when layer k is executing.
    The movement of the layer k-1 cache to the CPU is handled by the default stream as a simple way to
    ensure the eviction is scheduled after all computations on that cache are finished.
    """

    def __init__(self) -> None:
        if not torch.cuda.is_available():
            raise RuntimeError("OffloadedCache can only be used with a GPU")
        super().__init__()
        self.original_device = []
        self.prefetch_stream = torch.cuda.Stream()
        self.beam_idx = None  # used to delay beam search operations

    def prefetch_layer(self, layer_idx: int):
        "Starts prefetching the next layer cache"
        if layer_idx < len(self):
            with torch.cuda.stream(self.prefetch_stream):
                # Prefetch next layer tensors to GPU
                device = self.original_device[layer_idx]
                self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True)
                self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device, non_blocking=True)

    def evict_previous_layer(self, layer_idx: int):
        "Moves the previous layer cache to the CPU"
        if len(self) > 2:
            # We do it on the default stream so it occurs after all earlier computations on these tensors are done
            prev_layer_idx = (layer_idx - 1) % len(self)
            self.key_cache[prev_layer_idx] = self.key_cache[prev_layer_idx].to("cpu", non_blocking=True)
            self.value_cache[prev_layer_idx] = self.value_cache[prev_layer_idx].to("cpu", non_blocking=True)

    def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
        "Gets the cache for this layer to the device. Prefetches the next and evicts the previous layer."
        if layer_idx < len(self):
            # Evict the previous layer if necessary
            torch.cuda.current_stream().synchronize()
            self.evict_previous_layer(layer_idx)
            # Load current layer cache to its original device if not already there
            original_device = self.original_device[layer_idx]
            self.prefetch_stream.synchronize()
            key_tensor = self.key_cache[layer_idx]
            value_tensor = self.value_cache[layer_idx]
            # Now deal with beam search ops which were delayed
            if self.beam_idx is not None:
                self.beam_idx = self.beam_idx.to(original_device)
                key_tensor = key_tensor.index_select(0, self.beam_idx)
                value_tensor = value_tensor.index_select(0, self.beam_idx)
            # Prefetch the next layer
            self.prefetch_layer((layer_idx + 1) % len(self))
            return (key_tensor, value_tensor)
        else:
            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Saves the beam indices and reorders the cache when the tensor is back to its device."""
        # We delay this operation until the tensors are back to their original
        # device because performing torch.index_select on the CPU is very slow
        del self.beam_idx
        self.beam_idx = beam_idx.clone()

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. No additional arguments are used in `OffloadedCache`.
        Return:
            A tuple containing the updated key and value states.
        """
        # Update the number of seen tokens
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        # Update the cache
        if len(self.key_cache) <= layer_idx:
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
            self.original_device.append(key_states.device)
            self.evict_previous_layer(layer_idx)
        else:
            key_tensor, value_tensor = self[layer_idx]
            self.key_cache[layer_idx] = torch.cat([key_tensor, key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([value_tensor, value_states], dim=-2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    # According to https://docs.python.org/3/library/exceptions.html#NotImplementedError
    # if a method is not supposed to be supported in a subclass we should set it to None
    from_legacy_cache = None

    to_legacy_cache = None


class QuantizedCache(DynamicCache):
    """
    A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://arxiv.org/abs/2402.02750).
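
As a side note on the `OffloadedCache` class added above: because it subclasses `DynamicCache`, it may also be possible to construct the cache yourself and hand it to `generate()` via `past_key_values`, instead of going through the `cache_implementation="offloaded"` string. This is only a hedged sketch, not something this commit documents; whether user-provided cache instances are accepted this way depends on your Transformers version, and the import path below assumes the class is exposed in `transformers.cache_utils` as shown in the hunk above.

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.cache_utils import OffloadedCache  # assumed import path, per the hunk above

ckpt = "microsoft/Phi-3-mini-4k-instruct"
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16).to("cuda:0")
inputs = tokenizer("Fun fact: The shortest", return_tensors="pt").to(model.device)

# Requires a CUDA device; OffloadedCache raises a RuntimeError otherwise.
cache = OffloadedCache()
out = model.generate(**inputs, do_sample=False, max_new_tokens=23, past_key_values=cache)
print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
```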

@@ -32,6 +32,7 @@ from ..cache_utils import (
    HQQQuantizedCache,
    HybridCache,
    MambaCache,
    OffloadedCache,
    QuantizedCacheConfig,
    QuantoQuantizedCache,
    SlidingWindowCache,

@@ -1842,6 +1843,8 @@ class GenerationMixin:
                    )

                model_kwargs[cache_name] = cache_class(cache_config)
            elif generation_config.cache_implementation == "offloaded":
                model_kwargs[cache_name] = OffloadedCache()
        # Use DynamicCache() instance by default. This will avoid back and forth from legacy format that
        # keeps copying the cache thus using much more memory
        elif generation_config.cache_implementation is None and self._supports_default_dynamic_cache():

@@ -38,6 +38,7 @@ if is_torch_available():
        AutoModelForCausalLM,
        AutoTokenizer,
        DynamicCache,
        GenerationConfig,
        GPT2LMHeadModel,
        LlamaConfig,
        SinkCache,

@@ -513,3 +514,54 @@ class CacheIntegrationTest(unittest.TestCase):
    @unittest.skip(reason="TODO @gante static cache's does not support beam search yet")
    def test_static_cache_beam_search(self):
        pass

    @require_torch_gpu
    def test_offloaded_cache_equivalent_to_dynamic_cache(self):
        """Tests that OffloadedCache produces the same result as the default DynamicCache"""
        model_name = "microsoft/Phi-3-mini-4k-instruct"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
        device = model.device
        input_text = "Fun fact:"
        inputs = tokenizer(input_text, return_tensors="pt").to(device)
        common = {
            "num_beams": 4,
            "num_beam_groups": 2,
            "num_return_sequences": 4,
            "diversity_penalty": 1.0,
            "max_new_tokens": 20,
            "early_stopping": True,
        }
        original = GenerationConfig(**common)
        offloaded = GenerationConfig(cache_implementation="offloaded", **common)
        original_outputs = model.generate(generation_config=original, **inputs)
        offloaded_outputs = model.generate(generation_config=offloaded, **inputs)
        for original_output, offloaded_output in zip(original_outputs, offloaded_outputs):
            assert torch.all(original_output == offloaded_output).item()

    @require_torch_gpu
    def test_offloaded_cache_uses_less_memory_than_dynamic_cache(self):
        """Tests that OffloadedCache uses less memory than the default DynamicCache"""
        model_name = "microsoft/Phi-3-mini-4k-instruct"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
        device = model.device
        input_text = "Fun fact:"
        inputs = tokenizer(input_text, return_tensors="pt").to(device)
        common = {
            "num_beams": 4,
            "num_beam_groups": 2,
            "num_return_sequences": 4,
            "diversity_penalty": 1.0,
            "max_new_tokens": 20,
            "early_stopping": True,
        }
        original = GenerationConfig(**common)
        offloaded = GenerationConfig(cache_implementation="offloaded", **common)
        torch.cuda.reset_peak_memory_stats(device)
        model.generate(generation_config=original, **inputs)
        original_peak_memory = torch.cuda.max_memory_allocated(device)
        torch.cuda.reset_peak_memory_stats(device)
        model.generate(generation_config=offloaded, **inputs)
        offloaded_peak_memory = torch.cuda.max_memory_allocated(device)
        assert offloaded_peak_memory < original_peak_memory