Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-03 03:31:05 +06:00)
[caches] Raise exception on offloaded static caches + multi device (#37974)
* skip tests on >1 gpu
* add todo
parent 4279057d70
commit f2b59c6173
@@ -2028,6 +2028,13 @@ class OffloadedHybridCache(HybridChunkedCache):
         layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
     ):
         super().__init__(config, max_batch_size, max_cache_len, device, dtype, layer_device_map)
+
+        # TODO (joao): to enable this cache on multiple devices, use the pattern from `OffloadedCache`, which keeps
+        # track of the original device of each layer
+        unique_devices = set(layer_device_map.values()) if layer_device_map else set()
+        if len(unique_devices) > 1:
+            raise ValueError(f"OffloadedHybridCache does not support multiple devices. Got devices: {unique_devices}")
+
         self.offload_device = torch.device(offload_device)
         # Create new CUDA stream for parallel prefetching.
         self._prefetch_stream = torch.cuda.Stream() if torch._C._get_accelerator().type == "cuda" else None
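The guard's core check is easy to see in isolation; a minimal sketch (the example map below is made up for illustration):

    # Any layer->device map that names more than one distinct device is
    # rejected by the constructors in this commit.
    layer_device_map = {0: "cuda:0", 1: "cuda:0", 2: "cuda:1"}
    unique_devices = set(layer_device_map.values())  # {"cuda:0", "cuda:1"}
    print(len(unique_devices) > 1)  # True -> the new ValueError would be raised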
@@ -2280,6 +2287,13 @@ class OffloadedStaticCache(StaticCache):
         layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
     ) -> None:
         super(Cache, self).__init__()
+
+        # TODO (joao): to enable this cache on multiple devices, use the pattern from `OffloadedCache`, which keeps
+        # track of the original device of each layer
+        unique_devices = set(layer_device_map.values()) if layer_device_map else set()
+        if len(unique_devices) > 1:
+            raise ValueError(f"OffloadedStaticCache does not support multiple devices. Got devices: {unique_devices}")
+
         self.max_batch_size = max_batch_size
         self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
         self.device = torch.device(device) if layer_device_map is None else torch.device(layer_device_map[0])
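Taken together, the two hunks make both offloaded caches fail fast at construction time instead of silently misbehaving. A minimal sketch of the resulting behavior, assuming the constructor signature shown in the hunks ("gpt2" is used only to get a small config; the error is raised before any device allocation, so this runs even without two GPUs):

    import torch
    from transformers import AutoConfig
    from transformers.cache_utils import OffloadedStaticCache

    config = AutoConfig.from_pretrained("gpt2")

    # Spread layers across two devices: exactly the case the new guard rejects.
    layer_device_map = {i: f"cuda:{i % 2}" for i in range(config.num_hidden_layers)}

    try:
        OffloadedStaticCache(
            config=config,
            max_batch_size=1,
            max_cache_len=128,
            device="cuda:0",
            dtype=torch.float16,
            layer_device_map=layer_device_map,
        )
    except ValueError as err:
        print(err)  # OffloadedStaticCache does not support multiple devices. Got devices: ...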
@@ -198,19 +198,24 @@ class CacheIntegrationTest(unittest.TestCase):
         )
         cls.model.config.sliding_window = 256  # hack to enable the use of caches with sliding windows
 
-    def _skip_on_uninstalled_cache_dependencies(self, cache_implementation):
-        """Function to skip tests on missing cache dependencies, given a cache implementation"""
+    def _skip_on_failed_cache_prerequisites(self, cache_implementation):
+        """Function to skip tests on failed cache prerequisites, given a cache implementation"""
+        # Installed dependencies
         if cache_implementation == "quantized" and not is_optimum_quanto_available():
             self.skipTest("Quanto is not available")
+        # Devices
         if "offloaded" in cache_implementation:
             has_accelerator = torch_device is not None and torch_device != "cpu"
             if not has_accelerator:
                 self.skipTest("Offloaded caches require an accelerator")
+        if cache_implementation in ["offloaded_static", "offloaded_hybrid_chunked"]:
+            if torch.cuda.device_count() != 1:
+                self.skipTest("Offloaded static caches require exactly 1 GPU")
 
     @parameterized.expand(TEST_CACHE_IMPLEMENTATIONS)
     def test_cache_batched(self, cache_implementation):
         """Sanity check: caches' `.update` function expects batched inputs"""
-        self._skip_on_uninstalled_cache_dependencies(cache_implementation)
+        self._skip_on_failed_cache_prerequisites(cache_implementation)
 
         EXPECTED_GENERATION = ["A sequence: 1, 2, 3, 4, 5, 6, 7, 8,", "A sequence: A, B, C, D, E, F, G, H"]
 
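The renamed helper reads as a simple predicate over the environment. A standalone sketch for reference (the function name `cache_prerequisites_met` is hypothetical, and `torch.cuda.is_available()` stands in for the test suite's `torch_device` check):

    import torch
    from transformers.utils import is_optimum_quanto_available

    def cache_prerequisites_met(cache_implementation: str) -> tuple[bool, str]:
        """Approximates `_skip_on_failed_cache_prerequisites` outside of unittest."""
        # Installed dependencies: quantized caches need optimum-quanto.
        if cache_implementation == "quantized" and not is_optimum_quanto_available():
            return False, "Quanto is not available"
        # Devices: offloaded caches need an accelerator to offload from.
        if "offloaded" in cache_implementation and not torch.cuda.is_available():
            return False, "Offloaded caches require an accelerator"
        # Offloaded static caches are single-device only (see the guards above).
        if cache_implementation in ("offloaded_static", "offloaded_hybrid_chunked"):
            if torch.cuda.device_count() != 1:
                return False, "Offloaded static caches require exactly 1 GPU"
        return True, ""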
@@ -239,7 +244,7 @@ class CacheIntegrationTest(unittest.TestCase):
         Sanity check: caches' `reorder_cache` is operational. We can confirm this by looking at the beam indices
         (an output sequence contains multiple beam indices).
         """
-        self._skip_on_uninstalled_cache_dependencies(cache_implementation)
+        self._skip_on_failed_cache_prerequisites(cache_implementation)
         if cache_implementation == "offloaded_hybrid_chunked":
             # TODO (joao, cyril): something is off with `offloaded_hybrid_chunked` aka `OffloadedHybridCache`: the
             # output sequence (and the corresponding beam scores, if we add `output_scores=True`) are significantly
@@ -273,7 +278,7 @@ class CacheIntegrationTest(unittest.TestCase):
     @parameterized.expand(TEST_CACHE_IMPLEMENTATIONS)
     def test_cache_extra_left_padding(self, cache_implementation):
         """Tests that adding extra left-padding does not affect the generation with the cache"""
-        self._skip_on_uninstalled_cache_dependencies(cache_implementation)
+        self._skip_on_failed_cache_prerequisites(cache_implementation)
 
         EXPECTED_GENERATION = ["The cat's whiskers are also a sign of anxiety."]
 