Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-03 03:31:05 +06:00)
Fix the initialization of the cache when we have multi gpu (#33303)
* init cache multi-gpu
* Update src/transformers/generation/utils.py
  Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
* switch to execution device map
* naming more consistant
* fix
* mutually exclusive device
* added an integration example
* remove useless check
* suggestion from joao + typing
* fix couple of typo and add test
* revert check

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in: parent dfd31158ee, commit 6cc4dfe3f1
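Before the diff, a quick orientation example: the scenario this commit fixes is calling `generate` with the static cache on a model whose layers are split across two GPUs. This is a minimal sketch, not part of the diff; the tiny checkpoint and the manual device split are taken from the integration test added below.

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "hf-internal-testing/tiny-random-MistralForCausalLM"
device_map = {"model.embed_tokens": 0, "model.layers.0": 0, "model.layers.1": 1, "model.norm": 1, "lm_head": 0}

model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device_map)
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("Hello world", return_tensors="pt").to(0)
# With this fix, each layer's static cache block is allocated on that layer's execution device,
# so decoding no longer trips over cross-device cache tensors.
outputs = model.generate(**inputs, max_new_tokens=20, cache_implementation="static")
print(tokenizer.decode(outputs[0], skip_special_tokens=True))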
@@ -1030,6 +1030,9 @@ class StaticCache(Cache):
             The device on which the cache should be initialized. Should be the same as the layer.
         dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
             The default `dtype` to use when initializing the layer.
+        layer_device_map (`Dict[int, Union[str, torch.device, int]]`, *optional*):
+            Mapping between the layers and their devices. This is required when you are manually initializing the cache and the model is split between different GPUs.
+            You can check which layers are mapped to which device in the associated device map: `model.hf_device_map`.

     Example:

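To illustrate the new argument documented above (a sketch, not taken from the diff; the checkpoint, device split, and cache sizes mirror the new test at the end of this commit), a `layer_device_map` can be read off `model.hf_device_map` and passed when building the cache by hand:

from transformers import AutoModelForCausalLM, StaticCache

model = AutoModelForCausalLM.from_pretrained(
    "hf-internal-testing/tiny-random-MistralForCausalLM",
    device_map={"model.embed_tokens": 0, "model.layers.0": 0, "model.layers.1": 1, "model.norm": 1, "lm_head": 0},
)
# model.hf_device_map shows where each submodule landed; here layer 0 sits on GPU 0 and layer 1 on GPU 1.
layer_device_map = {0: 0, 1: 1}

past_key_values = StaticCache(
    config=model.config,
    batch_size=1,
    max_cache_len=30,
    device="cuda",
    dtype=model.dtype,
    layer_device_map=layer_device_map,
)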
@@ -1060,6 +1063,7 @@ class StaticCache(Cache):
         device: torch.device = None,
         dtype: torch.dtype = torch.float32,
         max_batch_size: Optional[int] = None,
+        layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
     ) -> None:
         super().__init__()
         if max_batch_size is not None:
@@ -1088,16 +1092,20 @@ class StaticCache(Cache):
         # Note: There will be significant perf decrease if switching to use 5D tensors instead.
         cache_shape = (self.batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
         for idx in range(config.num_hidden_layers):
-            new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
-            new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
+            if layer_device_map is not None:
+                layer_device = layer_device_map[idx]
+            else:
+                layer_device = device
+            new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
+            new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
             # Notes:
             # 1. `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph
             # breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case
             # it is not needed anyway)
             # 2. `torch.export()` requires mutations to be registered as buffers.
             if not is_torchdynamo_compiling():
-                self.register_buffer(f"key_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
-                self.register_buffer(f"value_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
+                self.register_buffer(f"key_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=layer_device))
+                self.register_buffer(f"value_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=layer_device))
                 new_layer_key_cache = getattr(self, f"key_cache_{idx}")
                 new_layer_value_cache = getattr(self, f"value_cache_{idx}")
             torch._dynamo.mark_static_address(new_layer_key_cache)
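Continuing the sketch from the docstring section above, the effect of this loop can be checked directly: every per-layer buffer ends up on the device requested in `layer_device_map` (illustrative only; reuses `past_key_values` and `layer_device_map` from that sketch):

import torch

for idx, (k, v) in enumerate(zip(past_key_values.key_cache, past_key_values.value_cache)):
    expected = torch.device(layer_device_map[idx])  # device this layer was mapped to
    assert k.device == expected and v.device == expected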
@@ -1130,9 +1138,9 @@ class StaticCache(Cache):
         Return:
             A tuple containing the updated key and value states.
         """
+
         cache_position = cache_kwargs.get("cache_position")
-        self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device=key_states.device)
-        self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device=value_states.device)
+
         k_out = self.key_cache[layer_idx]
         v_out = self.value_cache[layer_idx]

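The `.to(device=...)` lines dropped from `update` above were a per-step workaround that moved the whole cache block to the incoming states' device. With per-layer allocation that move is unnecessary, as this small sketch shows (assumes at least two CUDA devices; the shapes are made up):

import torch

if torch.cuda.device_count() >= 2:
    # Cache block for a layer that executes on GPU 1: (batch, num_kv_heads, max_cache_len, head_dim).
    cache_block = torch.zeros(1, 2, 8, 4, device="cuda:1")
    new_states = torch.randn(1, 2, 1, 4, device="cuda:1")   # key/value states produced on the same device
    cache_position = torch.tensor([3], device="cuda:1")     # slot written at this decoding step
    cache_block.index_copy_(2, cache_position, new_states)  # in-place update, no cross-device copy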
@@ -1201,6 +1209,9 @@ class SlidingWindowCache(StaticCache):
             The device on which the cache should be initialized. Should be the same as the layer.
         dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
             The default `dtype` to use when initializing the layer.
+        layer_device_map (`Dict[int, Union[str, torch.device, int]]`, *optional*):
+            Mapping between the layers and their devices. This is required when you are manually initializing the cache and the model is split between different GPUs.
+            You can check which layers are mapped to which device in the associated device map: `model.hf_device_map`.

     Example:

@@ -1231,6 +1242,7 @@ class SlidingWindowCache(StaticCache):
         device: torch.device = None,
         dtype: torch.dtype = torch.float32,
         max_batch_size: Optional[int] = None,
+        layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
     ) -> None:
         super().__init__()
         if not hasattr(config, "sliding_window") or config.sliding_window is None:
@@ -1247,6 +1259,7 @@ class SlidingWindowCache(StaticCache):
             device=device,
             dtype=dtype,
             max_batch_size=max_batch_size,
+            layer_device_map=layer_device_map,
         )

     def update(
@@ -1280,7 +1293,6 @@ class SlidingWindowCache(StaticCache):
         v_out = v_out[:, :, indices]

         try:
-            cache_position.to(device=k_out.device)
             k_out.index_copy_(2, cache_position, key_states)
             v_out.index_copy_(2, cache_position, value_states)
         except NotImplementedError:
@@ -1495,6 +1507,9 @@ class HybridCache(Cache):
             The device on which the cache should be initialized. Should be the same as the layer.
         dtype (torch.dtype, *optional*, defaults to `torch.float32`):
             The default `dtype` to use when initializing the layer.
+        layer_device_map (`Dict[int, Union[str, torch.device, int]]`, *optional*):
+            Mapping between the layers and their devices. This is required when you are manually initializing the cache and the model is split between different GPUs.
+            You can check which layers are mapped to which device in the associated device map: `model.hf_device_map`.

     Example:

@@ -1525,6 +1540,7 @@ class HybridCache(Cache):
         device: Union[torch.device, str] = "cpu",
         dtype: torch.dtype = torch.float32,
         max_batch_size: Optional[int] = None,
+        layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
     ) -> None:
         super().__init__()
         if max_batch_size is not None:
@@ -1562,11 +1578,15 @@ class HybridCache(Cache):
             self.head_dim,
         )
         for i in range(config.num_hidden_layers):
+            if layer_device_map is not None:
+                layer_device = layer_device_map[i]
+            else:
+                layer_device = device
             # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph
             # breaks when updating the cache.
             cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape
-            new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
-            new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
+            new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
+            new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
             torch._dynamo.mark_static_address(new_layer_key_cache)
             torch._dynamo.mark_static_address(new_layer_value_cache)
             self.key_cache.append(new_layer_key_cache)
@@ -1617,8 +1637,6 @@ class HybridCache(Cache):
     ) -> Tuple[torch.Tensor]:
         cache_position = cache_kwargs.get("cache_position")
         sliding_window = cache_kwargs.get("sliding_window")
-        self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device=key_states.device)
-        self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device=value_states.device)
         k_out = self.key_cache[layer_idx]
         v_out = self.value_cache[layer_idx]
         if sliding_window:
@@ -1446,12 +1446,39 @@ class GenerationMixin:
                     # models. May cause troubles with non-text modalities.
                     cache_dtype = self.get_output_embeddings().weight.dtype

+            def get_layer_device_map(execution_device_map: Optional[dict] = None):
+                if execution_device_map is None or len(execution_device_map) <= 1:
+                    return None
+                layer_device_map = {}
+                for layer in execution_device_map:
+                    for idx in range(self.config.num_hidden_layers):
+                        if f".{idx}." in f"{layer}.":
+                            layer_device_map[idx] = execution_device_map[layer]
+                            break
+                for idx in range(self.config.num_hidden_layers):
+                    if idx not in layer_device_map:
+                        raise RuntimeError(f"layer {idx} has not been mapped to a device.")
+                return layer_device_map
+
+            execution_device_map = None
+            # Taken from dispatch_model from accelerate.
+            # This is needed here if we don't want to make changes in accelerate in order to save execution_device
+            # For offloaded case, we need to get the execution device, not just the device where it is offloaded
+            if hasattr(self, "hf_device_map"):
+                main_device = [d for d in self.hf_device_map.values() if d not in ["cpu", "disk"]][0]
+                execution_device_map = {
+                    name: main_device if device in ["cpu", "disk"] else device
+                    for name, device in self.hf_device_map.items()
+                }
+            layer_device_map = get_layer_device_map(execution_device_map)
+
             cache_kwargs = {
                 "config": self.config if hasattr(self.config, "text_config") else self.config,
                 "max_batch_size": batch_size,
                 "max_cache_len": max_cache_len,
                 "device": device,
                 "dtype": cache_dtype,
+                "layer_device_map": layer_device_map,
             }
             self._cache = cache_cls(**cache_kwargs)
             if requires_cross_attention_cache:
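For clarity, here is a standalone rendering of what the new `get_layer_device_map` helper computes (a sketch, not the library function; the two-layer `hf_device_map` is hypothetical): module names containing `.{idx}.` are matched to hidden-layer indices after offloaded entries have been re-pointed at the main execution device.

from typing import Dict, Optional

def layer_map_from_device_map(execution_device_map: Optional[dict], num_hidden_layers: int) -> Optional[Dict[int, int]]:
    # Mirrors the matching logic added above, outside of GenerationMixin, for illustration.
    if execution_device_map is None or len(execution_device_map) <= 1:
        return None
    layer_device_map = {}
    for module_name, module_device in execution_device_map.items():
        for idx in range(num_hidden_layers):
            if f".{idx}." in f"{module_name}.":
                layer_device_map[idx] = module_device
                break
    return layer_device_map

# Hypothetical split of a 2-layer decoder across two GPUs:
hf_device_map = {"model.embed_tokens": 0, "model.layers.0": 0, "model.layers.1": 1, "model.norm": 1, "lm_head": 0}
print(layer_map_from_device_map(hf_device_map, num_hidden_layers=2))  # {0: 0, 1: 1}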
@@ -3444,6 +3444,91 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin):
         self.assertTrue(test_bos_id == gen_output[0, 0])
         self.assertTrue(generation_config.bos_token_id is None)

+    @pytest.mark.generate
+    @require_torch_multi_gpu
+    def test_generate_with_static_cache_multi_gpu(self):
+        """
+        Tests if the static cache has been set correctly and if generate works correctly when using multiple GPUs.
+        """
+        # need to split manually as auto doesn't work well with unbalanced model
+        device_map = {"model.embed_tokens": 0, "model.layers.0": 0, "model.layers.1": 1, "model.norm": 1, "lm_head": 0}
+        model = AutoModelForCausalLM.from_pretrained(
+            "hf-internal-testing/tiny-random-MistralForCausalLM", device_map=device_map
+        )
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
+
+        text = "Hello world"
+        tokenized_inputs = tokenizer([text], return_tensors="pt")
+        input_ids = tokenized_inputs.input_ids.to(torch_device)
+
+        generation_kwargs = {
+            "max_new_tokens": 20,
+            "cache_implementation": "static",
+            "return_dict_in_generate": True,  # Required to return `past_key_values`
+        }
+
+        results = model.generate(input_ids, **generation_kwargs)
+        self.assertTrue(isinstance(results.past_key_values, StaticCache))
+
+        # check device of each layer
+        key_cache_0 = results.past_key_values.key_cache[0]
+        value_cache_0 = results.past_key_values.value_cache[0]
+        self.assertTrue(key_cache_0.device == value_cache_0.device == torch.device(0))
+
+        key_cache_1 = results.past_key_values.key_cache[1]
+        value_cache_1 = results.past_key_values.value_cache[1]
+        self.assertTrue(key_cache_1.device == value_cache_1.device == torch.device(1))
+
+    @pytest.mark.generate
+    @require_torch_multi_gpu
+    def test_init_static_cache_multi_gpu(self):
+        """
+        Tests if the static cache has been set correctly when we initialize it manually in a multi-GPU setup.
+        """
+        # need to split manually as auto doesn't work well with unbalanced model
+        device_map = {"model.embed_tokens": 0, "model.layers.0": 0, "model.layers.1": 1, "model.norm": 1, "lm_head": 0}
+        model = AutoModelForCausalLM.from_pretrained(
+            "hf-internal-testing/tiny-random-MistralForCausalLM", device_map=device_map
+        )
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
+
+        text = "Hello world"
+        tokenized_inputs = tokenizer([text], return_tensors="pt")
+        input_ids = tokenized_inputs.input_ids.to(torch_device)
+
+        generation_kwargs = {
+            "max_new_tokens": 20,
+            "return_dict_in_generate": True,  # Required to return `past_key_values`
+        }
+
+        # TODO: We need to raise a warning in case the cache is not set correctly
+        # with self.assertRaisesRegex(ValueError, "If you are manually initializing the cache"):
+        #     past_key_values = StaticCache(
+        #         config=model.config, batch_size=1, max_cache_len=30, device=torch_device, dtype=model.dtype
+        #     )
+        #     results = model.generate(input_ids, past_key_values=past_key_values, **generation_kwargs)

+        # deduced from the device_map: layer 0 on device 0 and layer 1 on device 1
+        layer_device_map = {0: 0, 1: 1}
+        past_key_values = StaticCache(
+            config=model.config,
+            batch_size=1,
+            max_cache_len=30,
+            device=torch_device,
+            dtype=model.dtype,
+            layer_device_map=layer_device_map,
+        )
+        results = model.generate(input_ids, past_key_values=past_key_values, **generation_kwargs)
+
+        # check device of each layer
+        key_cache_0 = results.past_key_values.key_cache[0]
+        value_cache_0 = results.past_key_values.value_cache[0]
+        self.assertTrue(key_cache_0.device == value_cache_0.device == torch.device(0))
+
+        key_cache_1 = results.past_key_values.key_cache[1]
+        value_cache_1 = results.past_key_values.value_cache[1]
+        self.assertTrue(key_cache_1.device == value_cache_1.device == torch.device(1))
+

 @require_torch
 class TokenHealingTestCase(unittest.TestCase):