Offloaded KV Cache (#31325)

* Initial implementation of OffloadedCache

* enable usage via cache_implementation

* Address feedback, add tests, remove legacy methods.

* Remove flash-attn, discover synchronization bugs, fix bugs

* Prevent usage in CPU only mode

* Add a section about offloaded KV cache to the docs

* Fix typos in docs

* Clarifications and better explanation of streams
Nikos Karampatziakis 2024-08-01 05:42:07 -07:00 committed by GitHub
parent b4727a1216
commit ca59d6f77c
4 changed files with 241 additions and 0 deletions


@@ -211,6 +211,80 @@ I like rock music because it's loud and energetic. It's a great way to express m
I like rock music because it's loud and energetic. I like to listen to it when I'm feeling
```
## KV Cache Offloading
Similarly to KV cache quantization, this strategy aims to reduce GPU VRAM usage.
It does so by moving the KV cache for most layers to the CPU.
As the model's `forward()` method iterates over the layers, this strategy maintains the current layer cache on the GPU.
At the same time, it asynchronously prefetches the next layer's cache and sends the previous layer's cache back to the CPU.
Unlike KV cache quantization, this strategy always produces the same result as the default KV cache implementation.
Thus, it can serve as a drop-in replacement or a fallback for it.
Depending on your model and the characteristics of your generation task (size of context, number of generated tokens, number of beams, etc.),
you may notice a small degradation in generation throughput compared to the default KV cache implementation.
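The sketch below illustrates the idea (it is not the actual `OffloadedCache` code): copies issued on a separate CUDA stream can overlap with computation running on the default stream.

```python
import torch

# Illustrative only: prefetch layer k+1 on a side stream while layer k is being used,
# and send layer k-1 back to the CPU on the default stream.
prefetch_stream = torch.cuda.Stream()

def prefetch(tensor, device):
    with torch.cuda.stream(prefetch_stream):  # the copy overlaps with ongoing compute
        return tensor.to(device, non_blocking=True)

def evict(tensor):
    # queued on the default stream, so it runs only after all prior work on this tensor
    return tensor.to("cpu", non_blocking=True)
```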
To enable KV cache offloading, pass `cache_implementation="offloaded"` to `generate()` or set it in the `generation_config`.
```python
>>> import torch
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
>>> ckpt = "microsoft/Phi-3-mini-4k-instruct"
>>> tokenizer = AutoTokenizer.from_pretrained(ckpt)
>>> model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16).to("cuda:0")
>>> inputs = tokenizer("Fun fact: The shortest", return_tensors="pt").to(model.device)
>>> out = model.generate(**inputs, do_sample=False, max_new_tokens=23, cache_implementation="offloaded")
>>> print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
Fun fact: The shortest war in history was between Britain and Zanzibar on August 27, 1896.
>>> out = model.generate(**inputs, do_sample=False, max_new_tokens=23)
>>> print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
Fun fact: The shortest war in history was between Britain and Zanzibar on August 27, 1896.
```
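The same option can also be set on a `GenerationConfig`; a minimal sketch reusing `model` and `inputs` from above:

```python
>>> from transformers import GenerationConfig

>>> offloaded_config = GenerationConfig(cache_implementation="offloaded", do_sample=False, max_new_tokens=23)
>>> out = model.generate(**inputs, generation_config=offloaded_config)
```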
<Tip warning={true}>
Cache offloading requires a GPU and can be slower than the default KV cache. Use it if you are getting CUDA out of memory errors.
</Tip>
The example below shows how KV cache offloading can be used as a fallback strategy.
```python
>>> import torch
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
>>> def resilient_generate(model, *args, **kwargs):
... oom = False
... try:
... return model.generate(*args, **kwargs)
... except torch.cuda.OutOfMemoryError as e:
... print(e)
... print("retrying with cache_implementation='offloaded'")
... oom = True
... if oom:
... torch.cuda.empty_cache()
... kwargs["cache_implementation"] = "offloaded"
... return model.generate(*args, **kwargs)
...
...
>>> ckpt = "microsoft/Phi-3-mini-4k-instruct"
>>> tokenizer = AutoTokenizer.from_pretrained(ckpt)
>>> model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16).to("cuda:0")
>>> prompt = ["okay "*1000 + "Fun fact: The most"]
>>> inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
>>> beams = { "num_beams": 40, "num_beam_groups": 40, "num_return_sequences": 40, "diversity_penalty": 1.0, "max_new_tokens": 23, "early_stopping": True, }
>>> out = resilient_generate(model, **inputs, **beams)
>>> responses = tokenizer.batch_decode(out[:,-28:], skip_special_tokens=True)
```
On a GPU with 50 GB of VRAM, running this code will print
```
CUDA out of memory. Tried to allocate 4.83 GiB. GPU
retrying with cache_implementation='offloaded'
```
before successfully generating 40 beams.
## Watermarking
The `generate()` method supports watermarking the generated text by randomly marking a portion of tokens as "green".


@@ -450,6 +450,118 @@ class DynamicCache(Cache):
            self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]

class OffloadedCache(DynamicCache):
    """
    A drop-in replacement for DynamicCache that conserves GPU memory at the expense of more CPU memory.
    Useful for generating from models with very long context.

    In addition to the default CUDA stream, where all forward() computations happen,
    this class uses another stream, the prefetch stream, which it creates itself.
    Since scheduling of operations on separate streams happens independently, this class uses
    the prefetch stream to asynchronously prefetch the KV cache of layer k+1 when layer k is executing.
    The movement of the layer k-1 cache to the CPU is handled by the default stream as a simple way to
    ensure the eviction is scheduled after all computations on that cache are finished.
    """

    def __init__(self) -> None:
        if not torch.cuda.is_available():
            raise RuntimeError("OffloadedCache can only be used with a GPU")
        super().__init__()
        self.original_device = []
        self.prefetch_stream = torch.cuda.Stream()
        self.beam_idx = None  # used to delay beam search operations

    def prefetch_layer(self, layer_idx: int):
        "Starts prefetching the next layer cache"
        if layer_idx < len(self):
            with torch.cuda.stream(self.prefetch_stream):
                # Prefetch next layer tensors to GPU
                device = self.original_device[layer_idx]
                self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True)
                self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device, non_blocking=True)

    def evict_previous_layer(self, layer_idx: int):
        "Moves the previous layer cache to the CPU"
        if len(self) > 2:
            # We do it on the default stream so it occurs after all earlier computations on these tensors are done
            prev_layer_idx = (layer_idx - 1) % len(self)
            self.key_cache[prev_layer_idx] = self.key_cache[prev_layer_idx].to("cpu", non_blocking=True)
            self.value_cache[prev_layer_idx] = self.value_cache[prev_layer_idx].to("cpu", non_blocking=True)

    def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
        "Gets the cache for this layer on its original device. Prefetches the next layer and evicts the previous one."
        if layer_idx < len(self):
            # Evict the previous layer if necessary
            torch.cuda.current_stream().synchronize()
            self.evict_previous_layer(layer_idx)
            # Load current layer cache to its original device if not already there
            original_device = self.original_device[layer_idx]
            self.prefetch_stream.synchronize()
            key_tensor = self.key_cache[layer_idx]
            value_tensor = self.value_cache[layer_idx]
            # Now deal with beam search ops which were delayed
            if self.beam_idx is not None:
                self.beam_idx = self.beam_idx.to(original_device)
                key_tensor = key_tensor.index_select(0, self.beam_idx)
                value_tensor = value_tensor.index_select(0, self.beam_idx)
            # Prefetch the next layer
            self.prefetch_layer((layer_idx + 1) % len(self))
            return (key_tensor, value_tensor)
        else:
            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Saves the beam indices; the reordering is applied once the cache is back on its original device."""
        # We delay this operation until the tensors are back to their original
        # device because performing torch.index_select on the CPU is very slow
        del self.beam_idx
        self.beam_idx = beam_idx.clone()

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. No additional arguments are used in `OffloadedCache`.

        Return:
            A tuple containing the updated key and value states.
        """
        # Update the number of seen tokens
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        # Update the cache
        if len(self.key_cache) <= layer_idx:
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
            self.original_device.append(key_states.device)
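            # Prefill: this layer's cache was just created, so the previous layer's cache can be sent back to the CPU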
            self.evict_previous_layer(layer_idx)
        else:
            key_tensor, value_tensor = self[layer_idx]
            self.key_cache[layer_idx] = torch.cat([key_tensor, key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([value_tensor, value_states], dim=-2)
        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    # According to https://docs.python.org/3/library/exceptions.html#NotImplementedError
    # if a method is not supposed to be supported in a subclass we should set it to None
    from_legacy_cache = None
    to_legacy_cache = None


class QuantizedCache(DynamicCache):
    """
    A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://arxiv.org/abs/2402.02750).


@@ -32,6 +32,7 @@ from ..cache_utils import (
HQQQuantizedCache,
HybridCache,
MambaCache,
OffloadedCache,
QuantizedCacheConfig,
QuantoQuantizedCache,
SlidingWindowCache,
@@ -1842,6 +1843,8 @@ class GenerationMixin:
)
model_kwargs[cache_name] = cache_class(cache_config)
elif generation_config.cache_implementation == "offloaded":
model_kwargs[cache_name] = OffloadedCache()
# Use DynamicCache() instance by default. This will avoid back and forth from legacy format that
# keeps copying the cache thus using much more memory
elif generation_config.cache_implementation is None and self._supports_default_dynamic_cache():


@@ -38,6 +38,7 @@ if is_torch_available():
AutoModelForCausalLM,
AutoTokenizer,
DynamicCache,
GenerationConfig,
GPT2LMHeadModel,
LlamaConfig,
SinkCache,
@@ -513,3 +514,54 @@
    @unittest.skip(reason="TODO @gante static cache's does not support beam search yet")
    def test_static_cache_beam_search(self):
        pass

    @require_torch_gpu
    def test_offloaded_cache_equivalent_to_dynamic_cache(self):
        """Tests that OffloadedCache produces the same result as the default DynamicCache"""
        model_name = "microsoft/Phi-3-mini-4k-instruct"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
        device = model.device

        input_text = "Fun fact:"
        inputs = tokenizer(input_text, return_tensors="pt").to(device)
        common = {
            "num_beams": 4,
            "num_beam_groups": 2,
            "num_return_sequences": 4,
            "diversity_penalty": 1.0,
            "max_new_tokens": 20,
            "early_stopping": True,
        }
        original = GenerationConfig(**common)
        offloaded = GenerationConfig(cache_implementation="offloaded", **common)

        original_outputs = model.generate(generation_config=original, **inputs)
        offloaded_outputs = model.generate(generation_config=offloaded, **inputs)
        for original_output, offloaded_output in zip(original_outputs, offloaded_outputs):
            assert torch.all(original_output == offloaded_output).item()

    @require_torch_gpu
    def test_offloaded_cache_uses_less_memory_than_dynamic_cache(self):
        """Tests that OffloadedCache uses less memory than the default DynamicCache"""
        model_name = "microsoft/Phi-3-mini-4k-instruct"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
        device = model.device

        input_text = "Fun fact:"
        inputs = tokenizer(input_text, return_tensors="pt").to(device)
        common = {
            "num_beams": 4,
            "num_beam_groups": 2,
            "num_return_sequences": 4,
            "diversity_penalty": 1.0,
            "max_new_tokens": 20,
            "early_stopping": True,
        }
        original = GenerationConfig(**common)
        offloaded = GenerationConfig(cache_implementation="offloaded", **common)

        torch.cuda.reset_peak_memory_stats(device)
        model.generate(generation_config=original, **inputs)
        original_peak_memory = torch.cuda.max_memory_allocated(device)

        torch.cuda.reset_peak_memory_stats(device)
        model.generate(generation_config=offloaded, **inputs)
        offloaded_peak_memory = torch.cuda.max_memory_allocated(device)
        assert offloaded_peak_memory < original_peak_memory