https://github.com/huggingface/transformers.git
Generate: fix end to end compilation (#32465)
parent 6a03942db7
commit 3d8bd11942
@@ -1024,19 +1024,22 @@ class StaticCache(Cache):
         # Note: There will be significant perf decrease if switching to use 5D tensors instead.
         cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
         for idx in range(config.num_hidden_layers):
-            # Note: `torch.export()`` requires mutations to be registered as buffers.
-            self.register_buffer(f"key_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
-            self.register_buffer(f"value_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
-            key_cache = getattr(self, f"key_cache_{idx}")
-            value_cache = getattr(self, f"value_cache_{idx}")
-            # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
-            # breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case
-            # it is not needed anyway)
+            new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
+            new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
+            # Notes:
+            # 1. `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
+            #    breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case
+            #    it is not needed anyway)
+            # 2. `torch.export()` requires mutations to be registered as buffers.
             if not is_torchdynamo_compiling():
-                torch._dynamo.mark_static_address(key_cache)
-                torch._dynamo.mark_static_address(value_cache)
-            self.key_cache.append(key_cache)
-            self.value_cache.append(value_cache)
+                self.register_buffer(f"key_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
+                self.register_buffer(f"value_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
+                new_layer_key_cache = getattr(self, f"key_cache_{idx}")
+                new_layer_value_cache = getattr(self, f"value_cache_{idx}")
+                torch._dynamo.mark_static_address(new_layer_key_cache)
+                torch._dynamo.mark_static_address(new_layer_value_cache)
+            self.key_cache.append(new_layer_key_cache)
+            self.value_cache.append(new_layer_value_cache)

     def update(
         self,
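This hunk allocates each layer's tensors unconditionally, then performs buffer registration and address pinning only on the eager path, so creating the cache no longer breaks when the whole generation loop is compiled. A minimal, self-contained sketch of the same pattern (the `ToyStaticCache` module and shapes are illustrative, not transformers API):

```python
import torch
import torch.nn as nn


def is_torchdynamo_compiling() -> bool:
    # Stand-in for transformers' helper of the same name; recent PyTorch
    # exposes this check as `torch.compiler.is_compiling()`.
    return torch.compiler.is_compiling()


class ToyStaticCache(nn.Module):
    def __init__(self, num_layers: int = 2, cache_shape=(1, 4, 16, 8)):
        super().__init__()
        self.key_cache = []
        for idx in range(num_layers):
            new_layer_key_cache = torch.zeros(cache_shape)
            if not is_torchdynamo_compiling():
                # Eager path: register the tensor as a buffer (required by
                # `torch.export()`) and pin its address so cuda graphs do
                # not break when the cache is updated in place.
                self.register_buffer(f"key_cache_{idx}", new_layer_key_cache)
                new_layer_key_cache = getattr(self, f"key_cache_{idx}")
                torch._dynamo.mark_static_address(new_layer_key_cache)
            # Compiled path: append the plain tensor; neither call is
            # allowed (or needed) while dynamo is tracing.
            self.key_cache.append(new_layer_key_cache)
```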
@@ -1429,7 +1429,9 @@ class GenerationMixin:
             model_kwargs["cache_position"] = cache_position
         return model_kwargs

-    def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_len: int, model_kwargs) -> Cache:
+    def _get_cache(
+        self, cache_implementation: str, max_batch_size: int, max_cache_len: int, device: torch.device, model_kwargs
+    ) -> Cache:
         """
         Sets a cache for `generate`, that will persist across calls. A new cache will only be initialized a
         new `generate` call requires a larger cache or uses a different batch size.
@@ -1477,7 +1479,7 @@ class GenerationMixin:
                 "config": self.config,
                 "max_batch_size": max_batch_size,
                 "max_cache_len": max_cache_len,
-                "device": self.device,
+                "device": device,
                 "dtype": cache_dtype,
             }
             self._cache = cache_cls(**cache_kwargs)
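With the new `device` parameter, `_get_cache` places the cache on the device chosen by the caller instead of reading `self.device`. A hedged sketch of what these `cache_kwargs` construct, using a tiny illustrative `LlamaConfig` (all values are arbitrary):

```python
import torch
from transformers import LlamaConfig, StaticCache

# Tiny illustrative config; any config exposing layer and KV-head counts works.
config = LlamaConfig(
    num_hidden_layers=2, hidden_size=64, num_attention_heads=4, num_key_value_heads=4
)

# Mirrors the `cache_kwargs` dict in the hunk above.
cache = StaticCache(
    config=config,
    max_batch_size=2,
    max_cache_len=128,
    device=torch.device("cpu"),  # now supplied by the caller, not read from `self.device`
    dtype=torch.float32,
)
```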
@@ -1813,12 +1815,11 @@ class GenerationMixin:
                     "issue: https://github.com/huggingface/transformers/issues/28981"
                 )
                 model_kwargs[cache_name] = self._get_cache(
-                    generation_config.cache_implementation,
-                    getattr(generation_config, "num_beams", 1)
-                    * getattr(generation_config, "num_return_sequences", 1)
-                    * batch_size,
-                    generation_config.max_length,
-                    model_kwargs,
+                    cache_implementation=generation_config.cache_implementation,
+                    max_batch_size=generation_config.num_beams * generation_config.num_return_sequences * batch_size,
+                    max_cache_len=generation_config.max_length,
+                    device=device,
+                    model_kwargs=model_kwargs,
                 )
             elif generation_config.cache_implementation == "quantized":
                 if not self._supports_quantized_cache:
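Taken together, the call site now passes everything by keyword, including the resolved `device`, so cache setup can live inside the compiled region. A usage sketch of the end-to-end compilation this commit targets, assuming a checkpoint with static-cache support (the model id is illustrative):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "HuggingFaceTB/SmolLM-135M"  # illustrative; any static-cache-capable model
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
model.generation_config.cache_implementation = "static"

# Compile `generate` itself; the StaticCache is now created inside the
# compiled function, which is the path this commit fixes.
compiled_generate = torch.compile(model.generate, fullgraph=True)

inputs = tokenizer("Hello", return_tensors="pt")
output_ids = compiled_generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```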