diff --git a/docs/source/en/generation_strategies.md b/docs/source/en/generation_strategies.md
index 68430de643f..1f4df78b9a6 100644
--- a/docs/source/en/generation_strategies.md
+++ b/docs/source/en/generation_strategies.md
@@ -211,6 +211,80 @@ I like rock music because it's loud and energetic. It's a great way to express m
 I like rock music because it's loud and energetic. I like to listen to it when I'm feeling
 ```
+## KV Cache Offloading
+
+Similarly to KV cache quantization, this strategy aims to reduce GPU VRAM usage.
+It does so by moving the KV cache for most layers to the CPU.
+As the model's `forward()` method iterates over the layers, this strategy maintains the current layer cache on the GPU.
+At the same time, it asynchronously prefetches the next layer cache and sends the previous layer cache back to the CPU.
+Unlike KV cache quantization, this strategy always produces the same result as the default KV cache implementation.
+Thus, it can serve as a drop-in replacement or a fallback for it.
+
+Depending on your model and the characteristics of your generation task (size of context, number of generated tokens, number of beams, etc.),
+you may notice a small degradation in generation throughput compared to the default KV cache implementation.
+
+To enable KV cache offloading, pass `cache_implementation="offloaded"` to `generate()` or set it in the model's `generation_config`.
+
+```python
+>>> import torch
+>>> from transformers import AutoTokenizer, AutoModelForCausalLM
+>>> ckpt = "microsoft/Phi-3-mini-4k-instruct"
+
+>>> tokenizer = AutoTokenizer.from_pretrained(ckpt)
+>>> model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16).to("cuda:0")
+>>> inputs = tokenizer("Fun fact: The shortest", return_tensors="pt").to(model.device)
+
+>>> out = model.generate(**inputs, do_sample=False, max_new_tokens=23, cache_implementation="offloaded")
+>>> print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
+Fun fact: The shortest war in history was between Britain and Zanzibar on August 27, 1896.
+
+>>> out = model.generate(**inputs, do_sample=False, max_new_tokens=23)
+>>> print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
+Fun fact: The shortest war in history was between Britain and Zanzibar on August 27, 1896.
+```
+
+<Tip warning={true}>
+
+Cache offloading requires a GPU and can be slower than the default KV cache. Use it if you are getting CUDA out of memory errors.
+
+</Tip>
+
+The example below shows how KV cache offloading can be used as a fallback strategy.
+```python
+>>> import torch
+>>> from transformers import AutoTokenizer, AutoModelForCausalLM
+>>> def resilient_generate(model, *args, **kwargs):
+...     oom = False
+...     try:
+...         return model.generate(*args, **kwargs)
+...     except torch.cuda.OutOfMemoryError as e:
+...         print(e)
+...         print("retrying with cache_implementation='offloaded'")
+...         oom = True
+...     if oom:
+...         torch.cuda.empty_cache()
+...         kwargs["cache_implementation"] = "offloaded"
+...         return model.generate(*args, **kwargs)
+...
+...
+>>> ckpt = "microsoft/Phi-3-mini-4k-instruct"
+>>> tokenizer = AutoTokenizer.from_pretrained(ckpt)
+>>> model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16).to("cuda:0")
+>>> prompt = ["okay "*1000 + "Fun fact: The most"]
+>>> inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+>>> beams = { "num_beams": 40, "num_beam_groups": 40, "num_return_sequences": 40, "diversity_penalty": 1.0, "max_new_tokens": 23, "early_stopping": True, }
+>>> out = resilient_generate(model, **inputs, **beams)
+>>> responses = tokenizer.batch_decode(out[:,-28:], skip_special_tokens=True)
+```
+
+On a GPU with 50 GB of VRAM, running this code will print
+```
+CUDA out of memory. Tried to allocate 4.83 GiB. GPU
+retrying with cache_implementation='offloaded'
+```
+before successfully generating 40 beams.
+
+
 ## Watermarking
 
 The `generate()` supports watermarking the generated text by randomly marking a portion of tokens as "green".
diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index bfe13e27af7..d9a3a3a5a50 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -450,6 +450,118 @@ class DynamicCache(Cache):
             self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]
+
+class OffloadedCache(DynamicCache):
+    """
+    A drop-in replacement for DynamicCache that conserves GPU memory at the expense of more CPU memory.
+    Useful for generating from models with very long context.
+
+    In addition to the default CUDA stream, where all forward() computations happen,
+    this class uses another stream, the prefetch stream, which it creates itself.
+    Since scheduling of operations on separate streams happens independently, this class uses
+    the prefetch stream to asynchronously prefetch the KV cache of layer k+1 when layer k is executing.
+    The movement of the layer k-1 cache to the CPU is handled by the default stream as a simple way to
+    ensure the eviction is scheduled after all computations on that cache are finished.
+    """
+
+    def __init__(self) -> None:
+        if not torch.cuda.is_available():
+            raise RuntimeError("OffloadedCache can only be used with a GPU")
+        super().__init__()
+        self.original_device = []
+        self.prefetch_stream = torch.cuda.Stream()
+        self.beam_idx = None  # used to delay beam search operations
+
+    def prefetch_layer(self, layer_idx: int):
+        "Starts prefetching the next layer cache"
+        if layer_idx < len(self):
+            with torch.cuda.stream(self.prefetch_stream):
+                # Prefetch next layer tensors to GPU
+                device = self.original_device[layer_idx]
+                self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True)
+                self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device, non_blocking=True)
+
+    def evict_previous_layer(self, layer_idx: int):
+        "Moves the previous layer cache to the CPU"
+        if len(self) > 2:
+            # We do it on the default stream so it occurs after all earlier computations on these tensors are done
+            prev_layer_idx = (layer_idx - 1) % len(self)
+            self.key_cache[prev_layer_idx] = self.key_cache[prev_layer_idx].to("cpu", non_blocking=True)
+            self.value_cache[prev_layer_idx] = self.value_cache[prev_layer_idx].to("cpu", non_blocking=True)
+
+    def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
+        "Gets the cache for this layer on its original device. Prefetches the next layer and evicts the previous one."
+ if layer_idx < len(self): + # Evict the previous layer if necessary + torch.cuda.current_stream().synchronize() + self.evict_previous_layer(layer_idx) + # Load current layer cache to its original device if not already there + original_device = self.original_device[layer_idx] + self.prefetch_stream.synchronize() + key_tensor = self.key_cache[layer_idx] + value_tensor = self.value_cache[layer_idx] + # Now deal with beam search ops which were delayed + if self.beam_idx is not None: + self.beam_idx = self.beam_idx.to(original_device) + key_tensor = key_tensor.index_select(0, self.beam_idx) + value_tensor = value_tensor.index_select(0, self.beam_idx) + # Prefetch the next layer + self.prefetch_layer((layer_idx + 1) % len(self)) + return (key_tensor, value_tensor) + else: + raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Saves the beam indices and reorders the cache when the tensor is back to its device.""" + # We delay this operation until the tensors are back to their original + # device because performing torch.index_select on the CPU is very slow + del self.beam_idx + self.beam_idx = beam_idx.clone() + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. No additional arguments are used in `OffloadedCache`. + Return: + A tuple containing the updated key and value states. + """ + # Update the number of seen tokens + if layer_idx == 0: + self._seen_tokens += key_states.shape[-2] + + # Update the cache + if len(self.key_cache) <= layer_idx: + self.key_cache.append(key_states) + self.value_cache.append(value_states) + self.original_device.append(key_states.device) + self.evict_previous_layer(layer_idx) + else: + key_tensor, value_tensor = self[layer_idx] + self.key_cache[layer_idx] = torch.cat([key_tensor, key_states], dim=-2) + self.value_cache[layer_idx] = torch.cat([value_tensor, value_states], dim=-2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + # According to https://docs.python.org/3/library/exceptions.html#NotImplementedError + # if a method is not supposed to be supported in a subclass we should set it to None + from_legacy_cache = None + + to_legacy_cache = None + + class QuantizedCache(DynamicCache): """ A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://arxiv.org/abs/2402.02750). 
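To make the scheduling in `OffloadedCache` easier to follow, here is a minimal, self-contained sketch (not part of the patch) of the same two-stream prefetch/evict pattern. The layer count, the tensor shape, and the `mul_` call standing in for attention are illustrative assumptions; a real KV cache holds one key and one value tensor per layer.

```python
# Illustrative sketch of the per-layer prefetch/evict pattern described in the
# OffloadedCache docstring above. Layer count and tensor shape are made up.
import torch

assert torch.cuda.is_available(), "like OffloadedCache, this sketch needs a GPU"

num_layers = 4
caches = [torch.randn(1, 8, 1024, 64, pin_memory=True) for _ in range(num_layers)]
prefetch_stream = torch.cuda.Stream()

# Bring layer 0 onto the GPU before the mock forward pass starts.
caches[0] = caches[0].to("cuda", non_blocking=True)

for layer_idx in range(num_layers):
    # Wait until the prefetch issued for this layer has finished.
    prefetch_stream.synchronize()
    layer_cache = caches[layer_idx]

    # Start copying the next layer's cache to the GPU on the side stream.
    next_idx = (layer_idx + 1) % num_layers
    with torch.cuda.stream(prefetch_stream):
        caches[next_idx] = caches[next_idx].to("cuda", non_blocking=True)

    # "Attention" on the current layer's cache runs on the default stream.
    layer_cache.mul_(1.0)

    # Evict the previous layer on the default stream so the copy back to the
    # CPU is ordered after all of the work that used that cache.
    prev_idx = (layer_idx - 1) % num_layers
    caches[prev_idx] = caches[prev_idx].to("cpu", non_blocking=True)
```

As in `evict_previous_layer`, the copy back to the CPU is issued on the default stream, which keeps it ordered after every kernel that read the current layer's cache.
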
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index c9683791cb9..e7736537e0c 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -32,6 +32,7 @@ from ..cache_utils import ( HQQQuantizedCache, HybridCache, MambaCache, + OffloadedCache, QuantizedCacheConfig, QuantoQuantizedCache, SlidingWindowCache, @@ -1842,6 +1843,8 @@ class GenerationMixin: ) model_kwargs[cache_name] = cache_class(cache_config) + elif generation_config.cache_implementation == "offloaded": + model_kwargs[cache_name] = OffloadedCache() # Use DynamicCache() instance by default. This will avoid back and forth from legacy format that # keeps copying the cache thus using much more memory elif generation_config.cache_implementation is None and self._supports_default_dynamic_cache(): diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index 74dc5951ee9..2729a2989ab 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -38,6 +38,7 @@ if is_torch_available(): AutoModelForCausalLM, AutoTokenizer, DynamicCache, + GenerationConfig, GPT2LMHeadModel, LlamaConfig, SinkCache, @@ -513,3 +514,54 @@ class CacheIntegrationTest(unittest.TestCase): @unittest.skip(reason="TODO @gante static cache's does not support beam search yet") def test_static_cache_beam_search(self): pass + + @require_torch_gpu + def test_offloaded_cache_equivalent_to_dynamic_cache(self): + """Tests that OffloadedCache produces the same result as the default DynamicCache""" + model_name = "microsoft/Phi-3-mini-4k-instruct" + tokenizer = AutoTokenizer.from_pretrained(model_name) + model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16) + device = model.device + input_text = "Fun fact:" + inputs = tokenizer(input_text, return_tensors="pt").to(device) + common = { + "num_beams": 4, + "num_beam_groups": 2, + "num_return_sequences": 4, + "diversity_penalty": 1.0, + "max_new_tokens": 20, + "early_stopping": True, + } + original = GenerationConfig(**common) + offloaded = GenerationConfig(cache_implementation="offloaded", **common) + original_outputs = model.generate(generation_config=original, **inputs) + offloaded_outputs = model.generate(generation_config=offloaded, **inputs) + for original_output, offloaded_output in zip(original_outputs, offloaded_outputs): + assert torch.all(original_output == offloaded_output).item() + + @require_torch_gpu + def test_offloaded_cache_uses_less_memory_than_dynamic_cache(self): + """Tests that OffloadedCache uses less memory than the default DynamicCache""" + model_name = "microsoft/Phi-3-mini-4k-instruct" + tokenizer = AutoTokenizer.from_pretrained(model_name) + model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16) + device = model.device + input_text = "Fun fact:" + inputs = tokenizer(input_text, return_tensors="pt").to(device) + common = { + "num_beams": 4, + "num_beam_groups": 2, + "num_return_sequences": 4, + "diversity_penalty": 1.0, + "max_new_tokens": 20, + "early_stopping": True, + } + original = GenerationConfig(**common) + offloaded = GenerationConfig(cache_implementation="offloaded", **common) + torch.cuda.reset_peak_memory_stats(device) + model.generate(generation_config=original, **inputs) + original_peak_memory = torch.cuda.max_memory_allocated(device) + torch.cuda.reset_peak_memory_stats(device) + model.generate(generation_config=offloaded, **inputs) + offloaded_peak_memory = 
torch.cuda.max_memory_allocated(device) + assert offloaded_peak_memory < original_peak_memory
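
For reference, the peak-memory comparison in the second test can be reproduced outside the test suite with the sketch below. It only uses calls that appear in this patch (`GenerationConfig(cache_implementation="offloaded")`, `torch.cuda.reset_peak_memory_stats`, `torch.cuda.max_memory_allocated`); the reduced beam settings are illustrative assumptions rather than the exact values used in the tests.

```python
# Standalone sketch of the peak-memory comparison performed in the test above.
# Assumes a CUDA device with enough free memory to load the checkpoint in float16.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

ckpt = "microsoft/Phi-3-mini-4k-instruct"
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16).to("cuda:0")
inputs = tokenizer("Fun fact:", return_tensors="pt").to(model.device)


def peak_memory_bytes(generation_config):
    """Runs one generation and returns the peak GPU memory allocated during it, in bytes."""
    torch.cuda.reset_peak_memory_stats(model.device)
    model.generate(generation_config=generation_config, **inputs)
    return torch.cuda.max_memory_allocated(model.device)


common = {"num_beams": 4, "max_new_tokens": 20}
default_peak = peak_memory_bytes(GenerationConfig(**common))
offloaded_peak = peak_memory_bytes(GenerationConfig(cache_implementation="offloaded", **common))
print(f"default cache:   {default_peak / 2**20:.0f} MiB peak")
print(f"offloaded cache: {offloaded_peak / 2**20:.0f} MiB peak")
```

On a CUDA device the offloaded run should report a lower peak than the default cache, mirroring the assertion in the test.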