[caches] Raise exception on offloaded static caches + multi device (#37974)

* skip tests on >1 gpu

* add todo
This commit is contained in:
Joao Gante 2025-05-08 14:37:36 +01:00 committed by GitHub
parent 4279057d70
commit f2b59c6173
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 24 additions and 5 deletions

View File

@ -2028,6 +2028,13 @@ class OffloadedHybridCache(HybridChunkedCache):
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
):
super().__init__(config, max_batch_size, max_cache_len, device, dtype, layer_device_map)
# TODO (joao): to enable this cache on multiple devicesuse the pattern from `OffloadedCache`, which keeps
# track of the original device of each layer
unique_devices = set(layer_device_map.values())
if len(unique_devices) > 1:
raise ValueError(f"OffloadedHybridCache does not support multiple devices. Got devices: {unique_devices}")
self.offload_device = torch.device(offload_device)
# Create new CUDA stream for parallel prefetching.
self._prefetch_stream = torch.cuda.Stream() if torch._C._get_accelerator().type == "cuda" else None
@ -2280,6 +2287,13 @@ class OffloadedStaticCache(StaticCache):
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
) -> None:
super(Cache, self).__init__()
# TODO (joao): to enable this cache on multiple devicesuse the pattern from `OffloadedCache`, which keeps
# track of the original device of each layer
unique_devices = set(layer_device_map.values())
if len(unique_devices) > 1:
raise ValueError(f"OffloadedStaticCache does not support multiple devices. Got devices: {unique_devices}")
self.max_batch_size = max_batch_size
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
self.device = torch.device(device) if layer_device_map is None else torch.device(layer_device_map[0])

View File

@ -198,19 +198,24 @@ class CacheIntegrationTest(unittest.TestCase):
)
cls.model.config.sliding_window = 256 # hack to enable the use of caches with sliding windows
def _skip_on_uninstalled_cache_dependencies(self, cache_implementation):
"""Function to skip tests on missing cache dependencies, given a cache implementation"""
def _skip_on_failed_cache_prerequisites(self, cache_implementation):
"""Function to skip tests on failed cache prerequisites, given a cache implementation"""
# Installed dependencies
if cache_implementation == "quantized" and not is_optimum_quanto_available():
self.skipTest("Quanto is not available")
# Devices
if "offloaded" in cache_implementation:
has_accelerator = torch_device is not None and torch_device != "cpu"
if not has_accelerator:
self.skipTest("Offloaded caches require an accelerator")
if cache_implementation in ["offloaded_static", "offloaded_hybrid_chunked"]:
if torch.cuda.device_count() != 1:
self.skipTest("Offloaded static caches require exactly 1 GPU")
@parameterized.expand(TEST_CACHE_IMPLEMENTATIONS)
def test_cache_batched(self, cache_implementation):
"""Sanity check: caches' `.update` function expects batched inputs"""
self._skip_on_uninstalled_cache_dependencies(cache_implementation)
self._skip_on_failed_cache_prerequisites(cache_implementation)
EXPECTED_GENERATION = ["A sequence: 1, 2, 3, 4, 5, 6, 7, 8,", "A sequence: A, B, C, D, E, F, G, H"]
@ -239,7 +244,7 @@ class CacheIntegrationTest(unittest.TestCase):
Sanity check: caches' `reorder_cache` is operational. We can confirm this by looking at the beam indices
(an output sequence contains multiple beam indices).
"""
self._skip_on_uninstalled_cache_dependencies(cache_implementation)
self._skip_on_failed_cache_prerequisites(cache_implementation)
if cache_implementation == "offloaded_hybrid_chunked":
# TODO (joao, cyril): something is off with `offloaded_hybrid_chunked` aka `OffloadedHybridCache`: the
# output sequence (and the corresponding beam scores, if we add `output_scores=True`) are significantly
@ -273,7 +278,7 @@ class CacheIntegrationTest(unittest.TestCase):
@parameterized.expand(TEST_CACHE_IMPLEMENTATIONS)
def test_cache_extra_left_padding(self, cache_implementation):
"""Tests that adding extra left-padding does not affect the generation with the cache"""
self._skip_on_uninstalled_cache_dependencies(cache_implementation)
self._skip_on_failed_cache_prerequisites(cache_implementation)
EXPECTED_GENERATION = ["The cat's whiskers are also a sign of anxiety."]