Generate: fix end to end compilation (#32465)

This commit is contained in:
Joao Gante 2024-08-06 15:06:47 +01:00 committed by GitHub
parent 6a03942db7
commit 3d8bd11942
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 24 additions and 20 deletions

View File

@ -1024,19 +1024,22 @@ class StaticCache(Cache):
# Note: There will be significant perf decrease if switching to use 5D tensors instead.
cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
for idx in range(config.num_hidden_layers):
# Note: `torch.export()`` requires mutations to be registered as buffers.
self.register_buffer(f"key_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
self.register_buffer(f"value_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
key_cache = getattr(self, f"key_cache_{idx}")
value_cache = getattr(self, f"value_cache_{idx}")
# Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
# breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case
# it is not needed anyway)
new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
# Notes:
# 1. `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
# breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case
# it is not needed anyway)
# 2. `torch.export()` requires mutations to be registered as buffers.
if not is_torchdynamo_compiling():
torch._dynamo.mark_static_address(key_cache)
torch._dynamo.mark_static_address(value_cache)
self.key_cache.append(key_cache)
self.value_cache.append(value_cache)
self.register_buffer(f"key_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
self.register_buffer(f"value_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
new_layer_key_cache = getattr(self, f"key_cache_{idx}")
new_layer_value_cache = getattr(self, f"value_cache_{idx}")
torch._dynamo.mark_static_address(new_layer_key_cache)
torch._dynamo.mark_static_address(new_layer_value_cache)
self.key_cache.append(new_layer_key_cache)
self.value_cache.append(new_layer_value_cache)
def update(
self,

View File

@ -1429,7 +1429,9 @@ class GenerationMixin:
model_kwargs["cache_position"] = cache_position
return model_kwargs
def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_len: int, model_kwargs) -> Cache:
def _get_cache(
self, cache_implementation: str, max_batch_size: int, max_cache_len: int, device: torch.device, model_kwargs
) -> Cache:
"""
Sets a cache for `generate`, that will persist across calls. A new cache will only be initialized a
new `generate` call requires a larger cache or uses a different batch size.
@ -1477,7 +1479,7 @@ class GenerationMixin:
"config": self.config,
"max_batch_size": max_batch_size,
"max_cache_len": max_cache_len,
"device": self.device,
"device": device,
"dtype": cache_dtype,
}
self._cache = cache_cls(**cache_kwargs)
@ -1813,12 +1815,11 @@ class GenerationMixin:
"issue: https://github.com/huggingface/transformers/issues/28981"
)
model_kwargs[cache_name] = self._get_cache(
generation_config.cache_implementation,
getattr(generation_config, "num_beams", 1)
* getattr(generation_config, "num_return_sequences", 1)
* batch_size,
generation_config.max_length,
model_kwargs,
cache_implementation=generation_config.cache_implementation,
max_batch_size=generation_config.num_beams * generation_config.num_return_sequences * batch_size,
max_cache_len=generation_config.max_length,
device=device,
model_kwargs=model_kwargs,
)
elif generation_config.cache_implementation == "quantized":
if not self._supports_quantized_cache: