enable OffloadedCache on XPU from PyTorch 2.7 (#36654)

* fix "Cannot copy out of meta tensor; no data!" issue for BartForConditionalGeneration model

* follow Marc's suggestion to use _tie_weights to fix

Signed-off-by: Yao, Matrix <matrix.yao@intel.com>

* enable OffloadedCache on XPU since PyTorch 2.7

Signed-off-by: Yao, Matrix <matrix.yao@intel.com>

* fix style

Signed-off-by: Yao, Matrix <matrix.yao@intel.com>

* don't change bart

Signed-off-by: root <root@a4bf01945cfe.jf.intel.com>

* make code more concise per review comments

Signed-off-by: N <matrix.yao@intel.com>

* fix review comments

Signed-off-by: root <root@a4bf01945cfe.jf.intel.com>

* Revert "fix review comments"

This reverts commit acf1484b86.

* fix review comments

Signed-off-by: root <root@a4bf01945cfe.jf.intel.com>

* fix style

Signed-off-by: root <root@a4bf01945cfe.jf.intel.com>

---------

Signed-off-by: Yao, Matrix <matrix.yao@intel.com>
Signed-off-by: root <root@a4bf01945cfe.jf.intel.com>
Signed-off-by: N <matrix.yao@intel.com>
Co-authored-by: root <root@a4bf01945cfe.jf.intel.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
Yao Matrix 2025-03-19 22:15:52 +08:00 committed by GitHub
parent e8d960329e
commit b11050d6a2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 44 additions and 18 deletions

View File

@ -9,7 +9,7 @@ import torch
from packaging import version
from .configuration_utils import PretrainedConfig
from .utils import is_hqq_available, is_optimum_quanto_available, logging
from .utils import is_hqq_available, is_optimum_quanto_available, is_torch_greater_or_equal, logging
if is_hqq_available():
@ -537,10 +537,10 @@ class DynamicCache(Cache):
class OffloadedCache(DynamicCache):
"""
A drop-in replacement for DynamicCache that conserves GPU memory at the expense of more CPU memory.
A drop-in replacement for DynamicCache that conserves accelerator(GPU, XPU) memory at the expense of more CPU memory.
Useful for generating from models with very long context.
In addition to the default CUDA stream, where all forward() computations happen,
In addition to the default accelerator stream, where all forward() computations happen,
this class uses another stream, the prefetch stream, which it creates itself.
Since scheduling of operations on separate streams happens independently, this class uses
the prefetch stream to asynchronously prefetch the KV cache of layer k+1 when layer k is executing.
@ -549,17 +549,21 @@ class OffloadedCache(DynamicCache):
"""
def __init__(self) -> None:
if not torch.cuda.is_available():
raise RuntimeError("OffloadedCache can only be used with a GPU")
if not (torch.cuda.is_available() or (is_torch_greater_or_equal("2.7") and torch.xpu.is_available())):
raise RuntimeError(
"OffloadedCache can only be used with a GPU" + (" or XPU" if is_torch_greater_or_equal("2.7") else "")
)
super().__init__()
self.original_device = []
self.prefetch_stream = torch.cuda.Stream()
self.prefetch_stream = None
self.prefetch_stream = torch.Stream() if is_torch_greater_or_equal("2.7") else torch.cuda.Stream()
self.beam_idx = None # used to delay beam search operations
def prefetch_layer(self, layer_idx: int):
"Starts prefetching the next layer cache"
if layer_idx < len(self):
with torch.cuda.stream(self.prefetch_stream):
with self.prefetch_stream if is_torch_greater_or_equal("2.7") else torch.cuda.stream(self.prefetch_stream):
# Prefetch next layer tensors to GPU
device = self.original_device[layer_idx]
self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True)
@ -577,7 +581,10 @@ class OffloadedCache(DynamicCache):
"Gets the cache for this layer to the device. Prefetches the next and evicts the previous layer."
if layer_idx < len(self):
# Evict the previous layer if necessary
torch.cuda.current_stream().synchronize()
if is_torch_greater_or_equal("2.7"):
torch.accelerator.current_stream().synchronize()
else:
torch.cuda.current_stream().synchronize()
self.evict_previous_layer(layer_idx)
# Load current layer cache to its original device if not already there
original_device = self.original_device[layer_idx]

View File

@ -1062,7 +1062,9 @@ def is_torch_greater_or_equal(library_version: str):
if not _is_package_available("torch"):
return False
return version.parse(importlib.metadata.version("torch")) >= version.parse(library_version)
return version.parse(version.parse(importlib.metadata.version("torch")).base_version) >= version.parse(
library_version
)
def is_torchdistx_available():

View File

@ -27,6 +27,7 @@ from transformers.testing_utils import (
require_non_xpu,
require_read_token,
require_torch,
require_torch_accelerator,
require_torch_gpu,
require_torch_multi_gpu,
slow,
@ -48,7 +49,7 @@ if is_torch_available():
StaticCache,
convert_and_export_with_cache,
)
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_3
from transformers.utils import is_torch_greater_or_equal
@require_torch
@ -179,7 +180,7 @@ class CacheTest(unittest.TestCase):
"""
Tests that static cache works with `torch.export()`
"""
if not is_torch_greater_or_equal_than_2_3:
if not is_torch_greater_or_equal("2.3"):
self.skipTest(reason="This test requires torch >= 2.3 to run.")
set_seed(0)
@ -230,7 +231,7 @@ class CacheTest(unittest.TestCase):
self.assertEqual(n_static_value_caches, model.config.num_hidden_layers)
@require_torch_gpu
@require_torch_accelerator
@slow
class CacheIntegrationTest(unittest.TestCase):
def test_dynamic_cache_hard(self):
@ -542,13 +543,17 @@ class CacheIntegrationTest(unittest.TestCase):
def test_static_cache_beam_search(self):
pass
@require_torch_gpu
@require_torch_accelerator
def test_offloaded_cache_equivalent_to_dynamic_cache(self):
"""Tests that OffloadedCache produces the same result as the default DynamicCache"""
model_name = "microsoft/Phi-3-mini-4k-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
device = model.device
if not is_torch_greater_or_equal("2.7") and device.type == "xpu":
self.skipTest(reason="This test requires torch >= 2.7 to run on xpu.")
input_text = "Fun fact:"
inputs = tokenizer(input_text, return_tensors="pt").to(device)
common = {
@ -566,13 +571,17 @@ class CacheIntegrationTest(unittest.TestCase):
for original_output, offloaded_output in zip(original_outputs, offloaded_outputs):
assert torch.all(original_output == offloaded_output).item()
@require_torch_gpu
@require_torch_accelerator
def test_offloaded_cache_uses_less_memory_than_dynamic_cache(self):
"""Tests that OffloadedCache uses less memory than the default DynamicCache"""
model_name = "microsoft/Phi-3-mini-4k-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
device = model.device
if not is_torch_greater_or_equal("2.7") and device.type == "xpu":
self.skipTest(reason="This test requires torch >= 2.7 to run on xpu.")
input_text = "Fun fact:"
inputs = tokenizer(input_text, return_tensors="pt").to(device)
common = {
@ -585,12 +594,20 @@ class CacheIntegrationTest(unittest.TestCase):
}
original = GenerationConfig(**common)
offloaded = GenerationConfig(cache_implementation="offloaded", **common)
torch.cuda.reset_peak_memory_stats(device)
torch_accelerator_module = None
if device.type == "cuda":
torch_accelerator_module = torch.cuda
elif device.type == "xpu":
torch_accelerator_module = torch.xpu
torch_accelerator_module.reset_peak_memory_stats(device)
model.generate(generation_config=original, **inputs)
original_peak_memory = torch.cuda.max_memory_allocated(device)
torch.cuda.reset_peak_memory_stats(device)
original_peak_memory = torch_accelerator_module.max_memory_allocated(device)
torch_accelerator_module.reset_peak_memory_stats(device)
model.generate(generation_config=offloaded, **inputs)
offloaded_peak_memory = torch.cuda.max_memory_allocated(device)
offloaded_peak_memory = torch_accelerator_module.max_memory_allocated(device)
print(f"original_peak_memory: {original_peak_memory}, offloaded_peak_memory: {offloaded_peak_memory}")
assert offloaded_peak_memory < original_peak_memory
@require_torch_gpu