enable OffloadedCache on XPU starting with PyTorch 2.7 (#36654)

* fix "Cannot copy out of meta tensor; no data!" issue for BartForConditionalGeneration model

* follow Marc's suggestion and use _tie_weights for the fix

Signed-off-by: Yao, Matrix <matrix.yao@intel.com>

* enable OffloadedCache on XPU since PyTorch 2.7

Signed-off-by: Yao, Matrix <matrix.yao@intel.com>

* fix style

Signed-off-by: Yao, Matrix <matrix.yao@intel.com>

* don't change bart

Signed-off-by: root <root@a4bf01945cfe.jf.intel.com>

* make code more concise per review comments

Signed-off-by: N <matrix.yao@intel.com>

* fix review comments

Signed-off-by: root <root@a4bf01945cfe.jf.intel.com>

* Revert "fix review comments"

This reverts commit acf1484b86.

* fix review comments

Signed-off-by: root <root@a4bf01945cfe.jf.intel.com>

* fix style

Signed-off-by: root <root@a4bf01945cfe.jf.intel.com>

---------

Signed-off-by: Yao, Matrix <matrix.yao@intel.com>
Signed-off-by: root <root@a4bf01945cfe.jf.intel.com>
Signed-off-by: N <matrix.yao@intel.com>
Co-authored-by: root <root@a4bf01945cfe.jf.intel.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
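
A minimal usage sketch of what this enables (model name and prompt are taken from the tests below; cache_implementation="offloaded" is the switch that selects OffloadedCache during generation):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumes torch >= 2.7 for the XPU path; falls back to CUDA otherwise
device = "xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() else "cuda"

model_name = "microsoft/Phi-3-mini-4k-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to(device)

inputs = tokenizer("Fun fact:", return_tensors="pt").to(device)
# "offloaded" keeps only the layer currently in use (plus a prefetched one) on the accelerator,
# moving the rest of the KV cache to CPU
outputs = model.generate(**inputs, max_new_tokens=32, cache_implementation="offloaded")
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
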
3 changed files with 44 additions and 18 deletions

src/transformers/cache_utils.py

@@ -9,7 +9,7 @@ import torch
 from packaging import version

 from .configuration_utils import PretrainedConfig
-from .utils import is_hqq_available, is_optimum_quanto_available, logging
+from .utils import is_hqq_available, is_optimum_quanto_available, is_torch_greater_or_equal, logging


 if is_hqq_available():
@@ -537,10 +537,10 @@ class DynamicCache(Cache):

 class OffloadedCache(DynamicCache):
     """
-    A drop-in replacement for DynamicCache that conserves GPU memory at the expense of more CPU memory.
+    A drop-in replacement for DynamicCache that conserves accelerator(GPU, XPU) memory at the expense of more CPU memory.
     Useful for generating from models with very long context.

-    In addition to the default CUDA stream, where all forward() computations happen,
+    In addition to the default accelerator stream, where all forward() computations happen,
     this class uses another stream, the prefetch stream, which it creates itself.
     Since scheduling of operations on separate streams happens independently, this class uses
     the prefetch stream to asynchronously prefetch the KV cache of layer k+1 when layer k is executing.
@@ -549,17 +549,21 @@ class OffloadedCache(DynamicCache):
     """

     def __init__(self) -> None:
-        if not torch.cuda.is_available():
-            raise RuntimeError("OffloadedCache can only be used with a GPU")
+        if not (torch.cuda.is_available() or (is_torch_greater_or_equal("2.7") and torch.xpu.is_available())):
+            raise RuntimeError(
+                "OffloadedCache can only be used with a GPU" + (" or XPU" if is_torch_greater_or_equal("2.7") else "")
+            )
+
         super().__init__()
         self.original_device = []
-        self.prefetch_stream = torch.cuda.Stream()
+        self.prefetch_stream = None
+        self.prefetch_stream = torch.Stream() if is_torch_greater_or_equal("2.7") else torch.cuda.Stream()
         self.beam_idx = None  # used to delay beam search operations

     def prefetch_layer(self, layer_idx: int):
         "Starts prefetching the next layer cache"
         if layer_idx < len(self):
-            with torch.cuda.stream(self.prefetch_stream):
+            with self.prefetch_stream if is_torch_greater_or_equal("2.7") else torch.cuda.stream(self.prefetch_stream):
                 # Prefetch next layer tensors to GPU
                 device = self.original_device[layer_idx]
                 self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True)
@@ -577,7 +581,10 @@ class OffloadedCache(DynamicCache):
         "Gets the cache for this layer to the device. Prefetches the next and evicts the previous layer."
         if layer_idx < len(self):
             # Evict the previous layer if necessary
-            torch.cuda.current_stream().synchronize()
+            if is_torch_greater_or_equal("2.7"):
+                torch.accelerator.current_stream().synchronize()
+            else:
+                torch.cuda.current_stream().synchronize()
             self.evict_previous_layer(layer_idx)
             # Load current layer cache to its original device if not already there
             original_device = self.original_device[layer_idx]
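
The constructor guard above can be read as a standalone check; a sketch (the helper name is ours, not part of the library):

import torch
from transformers.utils import is_torch_greater_or_equal

def offloaded_cache_supported() -> bool:
    # CUDA is supported on all torch versions; XPU needs the torch >= 2.7 stream/accelerator APIs used above
    has_xpu = hasattr(torch, "xpu") and torch.xpu.is_available()
    return torch.cuda.is_available() or (is_torch_greater_or_equal("2.7") and has_xpu)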

src/transformers/utils/import_utils.py

@@ -1062,7 +1062,9 @@ def is_torch_greater_or_equal(library_version: str):
     if not _is_package_available("torch"):
         return False

-    return version.parse(importlib.metadata.version("torch")) >= version.parse(library_version)
+    return version.parse(version.parse(importlib.metadata.version("torch")).base_version) >= version.parse(
+        library_version
+    )


 def is_torchdistx_available():
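
The double parse above matters for nightly and source builds, whose version strings carry pre-release and local segments (e.g. "2.7.0a0+git1234"); comparing the raw Version against "2.7" would report False even though the build is effectively 2.7. A quick illustration (the version string is just an example):

from packaging import version

nightly = version.parse("2.7.0a0+git1234")
print(nightly >= version.parse("2.7"))                              # False: the a0 pre-release sorts before the 2.7 final
print(version.parse(nightly.base_version) >= version.parse("2.7"))  # True: base_version drops pre-release and local segments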

tests/utils/test_cache_utils.py

@@ -27,6 +27,7 @@ from transformers.testing_utils import (
     require_non_xpu,
     require_read_token,
     require_torch,
+    require_torch_accelerator,
     require_torch_gpu,
     require_torch_multi_gpu,
     slow,
@@ -48,7 +49,7 @@ if is_torch_available():
         StaticCache,
         convert_and_export_with_cache,
     )
-    from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_3
+    from transformers.utils import is_torch_greater_or_equal


 @require_torch
@@ -179,7 +180,7 @@ class CacheTest(unittest.TestCase):
         """
         Tests that static cache works with `torch.export()`
         """
-        if not is_torch_greater_or_equal_than_2_3:
+        if not is_torch_greater_or_equal("2.3"):
             self.skipTest(reason="This test requires torch >= 2.3 to run.")

         set_seed(0)
@@ -230,7 +231,7 @@ class CacheTest(unittest.TestCase):
         self.assertEqual(n_static_value_caches, model.config.num_hidden_layers)


-@require_torch_gpu
+@require_torch_accelerator
 @slow
 class CacheIntegrationTest(unittest.TestCase):
     def test_dynamic_cache_hard(self):
@@ -542,13 +543,17 @@ class CacheIntegrationTest(unittest.TestCase):
     def test_static_cache_beam_search(self):
         pass

-    @require_torch_gpu
+    @require_torch_accelerator
     def test_offloaded_cache_equivalent_to_dynamic_cache(self):
         """Tests that OffloadedCache produces the same result as the default DynamicCache"""
         model_name = "microsoft/Phi-3-mini-4k-instruct"
         tokenizer = AutoTokenizer.from_pretrained(model_name)
         model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
         device = model.device
+
+        if not is_torch_greater_or_equal("2.7") and device.type == "xpu":
+            self.skipTest(reason="This test requires torch >= 2.7 to run on xpu.")
+
         input_text = "Fun fact:"
         inputs = tokenizer(input_text, return_tensors="pt").to(device)
         common = {
@@ -566,13 +571,17 @@ class CacheIntegrationTest(unittest.TestCase):
         for original_output, offloaded_output in zip(original_outputs, offloaded_outputs):
             assert torch.all(original_output == offloaded_output).item()

-    @require_torch_gpu
+    @require_torch_accelerator
     def test_offloaded_cache_uses_less_memory_than_dynamic_cache(self):
         """Tests that OffloadedCache uses less memory than the default DynamicCache"""
         model_name = "microsoft/Phi-3-mini-4k-instruct"
         tokenizer = AutoTokenizer.from_pretrained(model_name)
         model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
         device = model.device
+
+        if not is_torch_greater_or_equal("2.7") and device.type == "xpu":
+            self.skipTest(reason="This test requires torch >= 2.7 to run on xpu.")
+
         input_text = "Fun fact:"
         inputs = tokenizer(input_text, return_tensors="pt").to(device)
         common = {
@@ -585,12 +594,20 @@ class CacheIntegrationTest(unittest.TestCase):
         }
         original = GenerationConfig(**common)
         offloaded = GenerationConfig(cache_implementation="offloaded", **common)
-        torch.cuda.reset_peak_memory_stats(device)
+
+        torch_accelerator_module = None
+        if device.type == "cuda":
+            torch_accelerator_module = torch.cuda
+        elif device.type == "xpu":
+            torch_accelerator_module = torch.xpu
+
+        torch_accelerator_module.reset_peak_memory_stats(device)
         model.generate(generation_config=original, **inputs)
-        original_peak_memory = torch.cuda.max_memory_allocated(device)
-        torch.cuda.reset_peak_memory_stats(device)
+        original_peak_memory = torch_accelerator_module.max_memory_allocated(device)
+        torch_accelerator_module.reset_peak_memory_stats(device)
         model.generate(generation_config=offloaded, **inputs)
-        offloaded_peak_memory = torch.cuda.max_memory_allocated(device)
+        offloaded_peak_memory = torch_accelerator_module.max_memory_allocated(device)
+        print(f"original_peak_memory: {original_peak_memory}, offloaded_peak_memory: {offloaded_peak_memory}")
         assert offloaded_peak_memory < original_peak_memory

 @require_torch_gpu
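
The per-device dispatch in the memory test above can be factored into a small helper; a sketch (the function name is ours; both torch.cuda and torch.xpu expose the memory-stats calls used here):

import torch

def get_torch_accelerator_module(device_type: str):
    # Only CUDA and XPU are exercised by the test; both provide reset_peak_memory_stats / max_memory_allocated
    if device_type == "cuda":
        return torch.cuda
    if device_type == "xpu":
        return torch.xpu
    raise ValueError(f"Unsupported device type: {device_type}")

# e.g. get_torch_accelerator_module(model.device.type).max_memory_allocated(model.device)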