Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-19 20:48:22 +06:00)
enable OffloadedCache on XPU from PyTorch 2.7 (#36654)
* fix "Cannot copy out of meta tensor; no data!" issue for BartForConditionalGeneration model
* follow Marc's suggestion to use _tie_weights to fix the issue
Signed-off-by: Yao, Matrix <matrix.yao@intel.com>
* enable OffloadedCache on XPU since PyTorch 2.7
Signed-off-by: Yao, Matrix <matrix.yao@intel.com>
* fix style
Signed-off-by: Yao, Matrix <matrix.yao@intel.com>
* don't change bart
Signed-off-by: root <root@a4bf01945cfe.jf.intel.com>
* make code more concise per review comments
Signed-off-by: N <matrix.yao@intel.com>
* fix review comments
Signed-off-by: root <root@a4bf01945cfe.jf.intel.com>
* Revert "fix review comments"
This reverts commit acf1484b86.
* fix review comments
Signed-off-by: root <root@a4bf01945cfe.jf.intel.com>
* fix style
Signed-off-by: root <root@a4bf01945cfe.jf.intel.com>
---------
Signed-off-by: Yao, Matrix <matrix.yao@intel.com>
Signed-off-by: root <root@a4bf01945cfe.jf.intel.com>
Signed-off-by: N <matrix.yao@intel.com>
Co-authored-by: root <root@a4bf01945cfe.jf.intel.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in: parent e8d960329e, commit b11050d6a2
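For context before the diff, a minimal usage sketch of the scenario this change enables; it assumes a PyTorch >= 2.7 build with XPU support (or any CUDA build) and reuses the checkpoint from the tests below:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumed setup: torch >= 2.7 with XPU support, or any CUDA build.
model_name = "microsoft/Phi-3-mini-4k-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)

inputs = tokenizer("Fun fact:", return_tensors="pt").to(model.device)
# "offloaded" selects OffloadedCache: KV tensors are kept on CPU and prefetched
# back to the accelerator one layer ahead of the layer currently computing.
output = model.generate(**inputs, max_new_tokens=32, cache_implementation="offloaded")
print(tokenizer.decode(output[0], skip_special_tokens=True))

The full diff follows.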
src/transformers/cache_utils.py
@@ -9,7 +9,7 @@ import torch
 from packaging import version
 
 from .configuration_utils import PretrainedConfig
-from .utils import is_hqq_available, is_optimum_quanto_available, logging
+from .utils import is_hqq_available, is_optimum_quanto_available, is_torch_greater_or_equal, logging
 
 
 if is_hqq_available():
@@ -537,10 +537,10 @@ class DynamicCache(Cache):
 
 class OffloadedCache(DynamicCache):
     """
-    A drop-in replacement for DynamicCache that conserves GPU memory at the expense of more CPU memory.
+    A drop-in replacement for DynamicCache that conserves accelerator(GPU, XPU) memory at the expense of more CPU memory.
     Useful for generating from models with very long context.
 
-    In addition to the default CUDA stream, where all forward() computations happen,
+    In addition to the default accelerator stream, where all forward() computations happen,
     this class uses another stream, the prefetch stream, which it creates itself.
     Since scheduling of operations on separate streams happens independently, this class uses
     the prefetch stream to asynchronously prefetch the KV cache of layer k+1 when layer k is executing.
@@ -549,17 +549,21 @@ class OffloadedCache(DynamicCache):
     """
 
     def __init__(self) -> None:
-        if not torch.cuda.is_available():
-            raise RuntimeError("OffloadedCache can only be used with a GPU")
+        if not (torch.cuda.is_available() or (is_torch_greater_or_equal("2.7") and torch.xpu.is_available())):
+            raise RuntimeError(
+                "OffloadedCache can only be used with a GPU" + (" or XPU" if is_torch_greater_or_equal("2.7") else "")
+            )
+
         super().__init__()
         self.original_device = []
-        self.prefetch_stream = torch.cuda.Stream()
+        self.prefetch_stream = None
+        self.prefetch_stream = torch.Stream() if is_torch_greater_or_equal("2.7") else torch.cuda.Stream()
         self.beam_idx = None  # used to delay beam search operations
 
     def prefetch_layer(self, layer_idx: int):
        "Starts prefetching the next layer cache"
        if layer_idx < len(self):
-            with torch.cuda.stream(self.prefetch_stream):
+            with self.prefetch_stream if is_torch_greater_or_equal("2.7") else torch.cuda.stream(self.prefetch_stream):
                # Prefetch next layer tensors to GPU
                device = self.original_device[layer_idx]
                self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True)
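The change above boils down to a device-agnostic stream-selection pattern. A condensed sketch of that pattern, assuming torch >= 2.7 for the generic torch.Stream path (the helper names here are illustrative, not part of the library):

import torch
from transformers.utils import is_torch_greater_or_equal

def make_prefetch_stream():
    # torch.Stream() (torch >= 2.7) is backend-agnostic, so one code path
    # covers CUDA and XPU; older torch only offers the CUDA-specific stream.
    if is_torch_greater_or_equal("2.7"):
        return torch.Stream()
    return torch.cuda.Stream()

def prefetch_to(tensor, device, stream):
    # Queue the copy on the side stream so it overlaps with forward() work
    # running on the default stream, mirroring prefetch_layer() above.
    ctx = stream if is_torch_greater_or_equal("2.7") else torch.cuda.stream(stream)
    with ctx:
        return tensor.to(device, non_blocking=True)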
@@ -577,7 +581,10 @@ class OffloadedCache(DynamicCache):
         "Gets the cache for this layer to the device. Prefetches the next and evicts the previous layer."
         if layer_idx < len(self):
             # Evict the previous layer if necessary
-            torch.cuda.current_stream().synchronize()
+            if is_torch_greater_or_equal("2.7"):
+                torch.accelerator.current_stream().synchronize()
+            else:
+                torch.cuda.current_stream().synchronize()
             self.evict_previous_layer(layer_idx)
             # Load current layer cache to its original device if not already there
             original_device = self.original_device[layer_idx]
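The synchronize call is what orders eviction after all compute on the previous layer's tensors; torch.accelerator (torch >= 2.7) gives a backend-neutral handle on the current stream. A small sketch of the same branch, with an illustrative helper name:

import torch
from transformers.utils import is_torch_greater_or_equal

def sync_compute_stream():
    # Wait for everything queued on the default compute stream before the
    # previous layer's KV tensors are moved back to CPU, as in the diff above.
    if is_torch_greater_or_equal("2.7"):
        # torch.accelerator dispatches to the active backend (CUDA or XPU).
        torch.accelerator.current_stream().synchronize()
    else:
        torch.cuda.current_stream().synchronize()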
src/transformers/utils/import_utils.py
@@ -1062,7 +1062,9 @@ def is_torch_greater_or_equal(library_version: str):
     if not _is_package_available("torch"):
         return False
 
-    return version.parse(importlib.metadata.version("torch")) >= version.parse(library_version)
+    return version.parse(version.parse(importlib.metadata.version("torch")).base_version) >= version.parse(
+        library_version
+    )
 
 
 def is_torchdistx_available():
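The extra base_version step matters because pre-release and local builds (common for XPU nightlies) carry suffixes that otherwise sort below the plain release number. A quick illustration with packaging (the version string is a made-up example):

from packaging import version

installed = "2.7.0a0+gitabcdef0"   # made-up pre-release/local build string
required = "2.7"

# Direct comparison: the a0 pre-release sorts below the 2.7 release -> False.
print(version.parse(installed) >= version.parse(required))
# With base_version the pre-release/local parts are stripped ("2.7.0") -> True.
print(version.parse(version.parse(installed).base_version) >= version.parse(required))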
tests/utils/test_cache_utils.py
@@ -27,6 +27,7 @@ from transformers.testing_utils import (
     require_non_xpu,
     require_read_token,
     require_torch,
+    require_torch_accelerator,
     require_torch_gpu,
     require_torch_multi_gpu,
     slow,
@@ -48,7 +49,7 @@ if is_torch_available():
         StaticCache,
         convert_and_export_with_cache,
     )
-    from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_3
+    from transformers.utils import is_torch_greater_or_equal
 
 
 @require_torch
@@ -179,7 +180,7 @@ class CacheTest(unittest.TestCase):
         """
         Tests that static cache works with `torch.export()`
         """
-        if not is_torch_greater_or_equal_than_2_3:
+        if not is_torch_greater_or_equal("2.3"):
             self.skipTest(reason="This test requires torch >= 2.3 to run.")
 
         set_seed(0)
@@ -230,7 +231,7 @@ class CacheTest(unittest.TestCase):
         self.assertEqual(n_static_value_caches, model.config.num_hidden_layers)
 
 
-@require_torch_gpu
+@require_torch_accelerator
 @slow
 class CacheIntegrationTest(unittest.TestCase):
     def test_dynamic_cache_hard(self):
@@ -542,13 +543,17 @@ class CacheIntegrationTest(unittest.TestCase):
     def test_static_cache_beam_search(self):
         pass
 
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_offloaded_cache_equivalent_to_dynamic_cache(self):
         """Tests that OffloadedCache produces the same result as the default DynamicCache"""
         model_name = "microsoft/Phi-3-mini-4k-instruct"
         tokenizer = AutoTokenizer.from_pretrained(model_name)
         model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
         device = model.device
+
+        if not is_torch_greater_or_equal("2.7") and device.type == "xpu":
+            self.skipTest(reason="This test requires torch >= 2.7 to run on xpu.")
+
         input_text = "Fun fact:"
         inputs = tokenizer(input_text, return_tensors="pt").to(device)
         common = {
@@ -566,13 +571,17 @@ class CacheIntegrationTest(unittest.TestCase):
         for original_output, offloaded_output in zip(original_outputs, offloaded_outputs):
             assert torch.all(original_output == offloaded_output).item()
 
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_offloaded_cache_uses_less_memory_than_dynamic_cache(self):
         """Tests that OffloadedCache uses less memory than the default DynamicCache"""
         model_name = "microsoft/Phi-3-mini-4k-instruct"
         tokenizer = AutoTokenizer.from_pretrained(model_name)
         model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
         device = model.device
+
+        if not is_torch_greater_or_equal("2.7") and device.type == "xpu":
+            self.skipTest(reason="This test requires torch >= 2.7 to run on xpu.")
+
         input_text = "Fun fact:"
         inputs = tokenizer(input_text, return_tensors="pt").to(device)
         common = {
@@ -585,12 +594,20 @@ class CacheIntegrationTest(unittest.TestCase):
         }
         original = GenerationConfig(**common)
         offloaded = GenerationConfig(cache_implementation="offloaded", **common)
-        torch.cuda.reset_peak_memory_stats(device)
+
+        torch_accelerator_module = None
+        if device.type == "cuda":
+            torch_accelerator_module = torch.cuda
+        elif device.type == "xpu":
+            torch_accelerator_module = torch.xpu
+
+        torch_accelerator_module.reset_peak_memory_stats(device)
         model.generate(generation_config=original, **inputs)
-        original_peak_memory = torch.cuda.max_memory_allocated(device)
-        torch.cuda.reset_peak_memory_stats(device)
+        original_peak_memory = torch_accelerator_module.max_memory_allocated(device)
+        torch_accelerator_module.reset_peak_memory_stats(device)
         model.generate(generation_config=offloaded, **inputs)
-        offloaded_peak_memory = torch.cuda.max_memory_allocated(device)
+        offloaded_peak_memory = torch_accelerator_module.max_memory_allocated(device)
+        print(f"original_peak_memory: {original_peak_memory}, offloaded_peak_memory: {offloaded_peak_memory}")
         assert offloaded_peak_memory < original_peak_memory
 
     @require_torch_gpu
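The explicit if/elif in the test keeps the diff minimal; an equivalent, slightly more generic way to pick the backend module (a sketch, not what the test does) is to look it up by device type:

import torch

def accelerator_module(device_type: str):
    # torch.cuda and torch.xpu expose the same memory-stats API used in the
    # test (reset_peak_memory_stats / max_memory_allocated).
    return {"cuda": torch.cuda, "xpu": torch.xpu}[device_type]

# Usage sketch, assuming `model`, `inputs`, and `offloaded` as in the test above:
# mod = accelerator_module(model.device.type)
# mod.reset_peak_memory_stats(model.device)
# model.generate(generation_config=offloaded, **inputs)
# peak = mod.max_memory_allocated(model.device)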