diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index 11c25b28278..558bcfb2e28 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -9,7 +9,7 @@ import torch
 from packaging import version
 
 from .configuration_utils import PretrainedConfig
-from .utils import is_hqq_available, is_optimum_quanto_available, logging
+from .utils import is_hqq_available, is_optimum_quanto_available, is_torch_greater_or_equal, logging
 
 
 if is_hqq_available():
@@ -537,10 +537,10 @@ class DynamicCache(Cache):
 
 class OffloadedCache(DynamicCache):
     """
-    A drop-in replacement for DynamicCache that conserves GPU memory at the expense of more CPU memory.
+    A drop-in replacement for DynamicCache that conserves accelerator(GPU, XPU) memory at the expense of more CPU memory.
     Useful for generating from models with very long context.
 
-    In addition to the default CUDA stream, where all forward() computations happen,
+    In addition to the default accelerator stream, where all forward() computations happen,
     this class uses another stream, the prefetch stream, which it creates itself.
     Since scheduling of operations on separate streams happens independently, this class uses
     the prefetch stream to asynchronously prefetch the KV cache of layer k+1 when layer k is executing.
@@ -549,17 +549,21 @@ class OffloadedCache(DynamicCache):
     """
 
     def __init__(self) -> None:
-        if not torch.cuda.is_available():
-            raise RuntimeError("OffloadedCache can only be used with a GPU")
+        if not (torch.cuda.is_available() or (is_torch_greater_or_equal("2.7") and torch.xpu.is_available())):
+            raise RuntimeError(
+                "OffloadedCache can only be used with a GPU" + (" or XPU" if is_torch_greater_or_equal("2.7") else "")
+            )
+
         super().__init__()
         self.original_device = []
-        self.prefetch_stream = torch.cuda.Stream()
+        self.prefetch_stream = None
+        self.prefetch_stream = torch.Stream() if is_torch_greater_or_equal("2.7") else torch.cuda.Stream()
         self.beam_idx = None  # used to delay beam search operations
 
     def prefetch_layer(self, layer_idx: int):
         "Starts prefetching the next layer cache"
         if layer_idx < len(self):
-            with torch.cuda.stream(self.prefetch_stream):
+            with self.prefetch_stream if is_torch_greater_or_equal("2.7") else torch.cuda.stream(self.prefetch_stream):
                 # Prefetch next layer tensors to GPU
                 device = self.original_device[layer_idx]
                 self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True)
@@ -577,7 +581,10 @@ class OffloadedCache(DynamicCache):
         "Gets the cache for this layer to the device. Prefetches the next and evicts the previous layer."
         if layer_idx < len(self):
             # Evict the previous layer if necessary
-            torch.cuda.current_stream().synchronize()
+            if is_torch_greater_or_equal("2.7"):
+                torch.accelerator.current_stream().synchronize()
+            else:
+                torch.cuda.current_stream().synchronize()
             self.evict_previous_layer(layer_idx)
             # Load current layer cache to its original device if not already there
             original_device = self.original_device[layer_idx]
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index 4f2724f1c8a..ad4e685b24f 100755
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -1062,7 +1062,9 @@ def is_torch_greater_or_equal(library_version: str):
     if not _is_package_available("torch"):
         return False
 
-    return version.parse(importlib.metadata.version("torch")) >= version.parse(library_version)
+    return version.parse(version.parse(importlib.metadata.version("torch")).base_version) >= version.parse(
+        library_version
+    )
 
 
 def is_torchdistx_available():
diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py
index fc7617e6493..efe4e6af5c1 100644
--- a/tests/utils/test_cache_utils.py
+++ b/tests/utils/test_cache_utils.py
@@ -27,6 +27,7 @@ from transformers.testing_utils import (
     require_non_xpu,
     require_read_token,
     require_torch,
+    require_torch_accelerator,
     require_torch_gpu,
     require_torch_multi_gpu,
     slow,
@@ -48,7 +49,7 @@ if is_torch_available():
         StaticCache,
         convert_and_export_with_cache,
     )
-    from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_3
+    from transformers.utils import is_torch_greater_or_equal
 
 
 @require_torch
@@ -179,7 +180,7 @@ class CacheTest(unittest.TestCase):
         """
         Tests that static cache works with `torch.export()`
         """
-        if not is_torch_greater_or_equal_than_2_3:
+        if not is_torch_greater_or_equal("2.3"):
             self.skipTest(reason="This test requires torch >= 2.3 to run.")
 
         set_seed(0)
@@ -230,7 +231,7 @@ class CacheTest(unittest.TestCase):
         self.assertEqual(n_static_value_caches, model.config.num_hidden_layers)
 
 
-@require_torch_gpu
+@require_torch_accelerator
 @slow
 class CacheIntegrationTest(unittest.TestCase):
     def test_dynamic_cache_hard(self):
@@ -542,13 +543,17 @@ class CacheIntegrationTest(unittest.TestCase):
     def test_static_cache_beam_search(self):
         pass
 
-    @require_torch_gpu
+    @require_torch_accelerator
    def test_offloaded_cache_equivalent_to_dynamic_cache(self):
        """Tests that OffloadedCache produces the same result as the default DynamicCache"""
        model_name = "microsoft/Phi-3-mini-4k-instruct"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
        device = model.device
+
+        if not is_torch_greater_or_equal("2.7") and device.type == "xpu":
+            self.skipTest(reason="This test requires torch >= 2.7 to run on xpu.")
+
         input_text = "Fun fact:"
         inputs = tokenizer(input_text, return_tensors="pt").to(device)
         common = {
@@ -566,13 +571,17 @@ class CacheIntegrationTest(unittest.TestCase):
         for original_output, offloaded_output in zip(original_outputs, offloaded_outputs):
             assert torch.all(original_output == offloaded_output).item()
 
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_offloaded_cache_uses_less_memory_than_dynamic_cache(self):
         """Tests that OffloadedCache uses less memory than the default DynamicCache"""
         model_name = "microsoft/Phi-3-mini-4k-instruct"
         tokenizer = AutoTokenizer.from_pretrained(model_name)
         model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
         device = model.device
+
+        if not is_torch_greater_or_equal("2.7") and device.type == "xpu":
+            self.skipTest(reason="This test requires torch >= 2.7 to run on xpu.")
+
         input_text = "Fun fact:"
         inputs = tokenizer(input_text, return_tensors="pt").to(device)
         common = {
@@ -585,12 +594,20 @@ class CacheIntegrationTest(unittest.TestCase):
         }
         original = GenerationConfig(**common)
         offloaded = GenerationConfig(cache_implementation="offloaded", **common)
-        torch.cuda.reset_peak_memory_stats(device)
+
+        torch_accelerator_module = None
+        if device.type == "cuda":
+            torch_accelerator_module = torch.cuda
+        elif device.type == "xpu":
+            torch_accelerator_module = torch.xpu
+
+        torch_accelerator_module.reset_peak_memory_stats(device)
         model.generate(generation_config=original, **inputs)
-        original_peak_memory = torch.cuda.max_memory_allocated(device)
-        torch.cuda.reset_peak_memory_stats(device)
+        original_peak_memory = torch_accelerator_module.max_memory_allocated(device)
+        torch_accelerator_module.reset_peak_memory_stats(device)
         model.generate(generation_config=offloaded, **inputs)
-        offloaded_peak_memory = torch.cuda.max_memory_allocated(device)
+        offloaded_peak_memory = torch_accelerator_module.max_memory_allocated(device)
+        print(f"original_peak_memory: {original_peak_memory}, offloaded_peak_memory: {offloaded_peak_memory}")
         assert offloaded_peak_memory < original_peak_memory
 
     @require_torch_gpu
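
Note on the `is_torch_greater_or_equal` change in `import_utils.py`: pre-release and source builds of torch report versions such as `2.7.0a0+gitabcdef`, which PEP 440 orders below the final `2.7` release, so the raw comparison would disable the torch >= 2.7 code paths on builds that typically already include them. Below is a minimal sketch of the difference, using only `packaging`; the version string is a made-up example, not taken from the patch.

    from packaging import version

    build = "2.7.0a0+gitabcdef"  # hypothetical pre-release/source-build version string

    # Raw comparison: the a0 pre-release sorts below the final 2.7 release.
    print(version.parse(build) >= version.parse("2.7"))  # False

    # Patched comparison: base_version strips the pre-release/local suffix
    # ("2.7.0a0+gitabcdef" -> "2.7.0") before comparing.
    print(version.parse(version.parse(build).base_version) >= version.parse("2.7"))  # True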