mirror of https://github.com/huggingface/transformers.git
parent 31bb662db1
commit 5e2183f344
@@ -9,12 +9,7 @@ import torch
 from packaging import version
 
 from .configuration_utils import PretrainedConfig
-from .utils import (
-    is_hqq_available,
-    is_optimum_quanto_available,
-    is_torchdynamo_compiling,
-    logging,
-)
+from .utils import is_hqq_available, is_optimum_quanto_available, logging
 from .utils.deprecation import deprecate_kwarg
 
 
@@ -24,7 +19,7 @@ if is_hqq_available():
 logger = logging.get_logger(__name__)
 
 
-class Cache(torch.nn.Module):
+class Cache:
     """
     Base, abstract class for all caches. The actual data structure is specific to each subclass.
     """
@@ -1140,18 +1135,10 @@ class StaticCache(Cache):
                 layer_device = self.device
             new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
             new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
-            # Notes:
-            # 1. `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
-            # breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case
-            # it is not needed anyway)
-            # 2. `torch.export()` requires mutations to be registered as buffers.
-            if not is_torchdynamo_compiling():
-                self.register_buffer(f"key_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=layer_device))
-                self.register_buffer(f"value_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=layer_device))
-                new_layer_key_cache = getattr(self, f"key_cache_{idx}")
-                new_layer_value_cache = getattr(self, f"value_cache_{idx}")
-                torch._dynamo.mark_static_address(new_layer_key_cache)
-                torch._dynamo.mark_static_address(new_layer_value_cache)
+            # Note: `mark_static_address` is used to tag the cache as a fixed data pointer,
+            # preventing compiled graph breaks when updating the cache.
+            torch._dynamo.mark_static_address(new_layer_key_cache)
+            torch._dynamo.mark_static_address(new_layer_value_cache)
             self.key_cache.append(new_layer_key_cache)
             self.value_cache.append(new_layer_value_cache)
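The `mark_static_address` call kept above matters because CUDA-graph execution under `torch.compile` assumes the cache tensor keeps a single fixed storage for its whole lifetime. A minimal sketch of the idea, assuming only standard PyTorch; the shape and dtype below are made up for illustration:

```python
import torch

# Hypothetical cache shape: (batch, num_key_value_heads, max_cache_len, head_dim)
cache_shape = (1, 8, 128, 64)
key_cache = torch.zeros(cache_shape, dtype=torch.float32)
value_cache = torch.zeros(cache_shape, dtype=torch.float32)

# Tag the preallocated tensors so torch.compile treats their addresses as static;
# in-place cache writes inside a compiled region then reuse the same storage
# instead of forcing the CUDA graph to be re-recorded.
torch._dynamo.mark_static_address(key_cache)
torch._dynamo.mark_static_address(value_cache)
```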
@@ -16,10 +16,7 @@ from ..utils.import_utils import is_torch_available
 
 
 if is_torch_available():
-    from transformers import (
-        PreTrainedModel,
-        StaticCache,
-    )
+    from transformers import PreTrainedModel, StaticCache
     from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_3
 
 
@@ -72,9 +69,13 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
             config=self.model.config,
             batch_size=self.model.generation_config.cache_config.batch_size,
             max_cache_len=self.model.generation_config.cache_config.max_cache_len,
-            dtype=self.model.dtype,
             device=self.model.generation_config.cache_config.device,
+            dtype=self.model.dtype,
         )
+        for i in range(len(self.static_cache.key_cache)):
+            self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i], persistent=False)
+            self.register_buffer(f"value_cache_{i}", self.static_cache.value_cache[i], persistent=False)
 
         self.is_causal = any("CausalLM" in arch for arch in self.model.config.architectures)
         if self.is_causal:
             causal_mask = torch.tril(
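Registering each layer's key/value tensor as a non-persistent buffer on the export wrapper is what keeps the cache visible to `torch.export`: exported programs expose module buffers through `named_buffers()`, which the helpers further down rely on. A small self-contained sketch of that behaviour, assuming a recent PyTorch (2.3+); the toy module and shapes are assumptions, not the repository's code:

```python
import torch


class ToyCacheWrapper(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Toy cache tensors; persistent=False keeps them out of the state_dict
        # while still exposing them as buffers.
        self.register_buffer("key_cache_0", torch.zeros(1, 8, 16, 64), persistent=False)
        self.register_buffer("value_cache_0", torch.zeros(1, 8, 16, 64), persistent=False)

    def forward(self, x):
        # Touch both buffers so they are part of the exported graph.
        return x + self.key_cache_0.sum() + self.value_cache_0.sum()


exported = torch.export.export(ToyCacheWrapper(), (torch.randn(1),))
# Buffers registered directly on the wrapper carry no "static_cache." prefix,
# which is why the lookups below match on the bare key_cache / value_cache names.
print([name for name, _ in exported.named_buffers()])
```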
@@ -109,12 +110,15 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
         """
         _, seqlen = input_ids.shape
         attn_mask = self.mask[cache_position, :seqlen] if self.is_causal else None
+        position_ids = cache_position.unsqueeze(0)
+        past_key_values = self.static_cache
+
         outs = self.model(
             input_ids=input_ids,
             attention_mask=attn_mask,
-            position_ids=cache_position.unsqueeze(0),
+            position_ids=position_ids,
             cache_position=cache_position,
-            past_key_values=self.static_cache,
+            past_key_values=past_key_values,
             use_cache=True,
         )
         return outs.logits
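The rewritten `forward` keeps the same calling convention: logits for the provided tokens come back directly, while the `StaticCache` stored on the module is updated in place at the given `cache_position` indices. An assumed usage pattern, illustrative only and not taken from the repository:

```python
import torch


def prefill_and_decode_once(wrapper: torch.nn.Module, prompt_ids: torch.Tensor) -> torch.Tensor:
    """Sketch: prefill the static cache with the prompt, then run one greedy decode step."""
    seqlen = prompt_ids.shape[-1]

    # Prefill: cache_position covers every prompt token.
    logits = wrapper(input_ids=prompt_ids, cache_position=torch.arange(seqlen, dtype=torch.long))
    next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)

    # Decode: a single new token written at absolute position `seqlen`.
    return wrapper(input_ids=next_token, cache_position=torch.tensor([seqlen], dtype=torch.long))
```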
@@ -143,7 +147,7 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
         prompt_token_len = prompt_token_ids.shape[-1]
         max_generation_length = prompt_token_len + max_new_tokens
         for buffer_name, buffer in exported_program.named_buffers():
-            if buffer_name.startswith("static_cache.key_cache"):
+            if buffer_name.startswith("key_cache"):
                 max_cache_len = buffer.shape[2]
                 max_generation_length = min(max_generation_length, max_cache_len)
                 break
@@ -215,11 +215,11 @@ class CacheTest(unittest.TestCase):
         # Check if the exported model is configured with the `StaticCache` correctly
         n_static_key_caches = n_static_value_caches = 0
         for buffer_name, buffer in exported_program.named_buffers():
-            if buffer_name.startswith("static_cache.key_cache"):
+            if buffer_name.startswith("key_cache"):
                 self.assertTrue(buffer.shape[0] == batch_size)
                 self.assertTrue(buffer.shape[2] == max_cache_len)
                 n_static_key_caches = n_static_key_caches + 1
-            if buffer_name.startswith("static_cache.value_cache"):
+            if buffer_name.startswith("value_cache"):
                 self.assertTrue(buffer.shape[0] == batch_size)
                 self.assertTrue(buffer.shape[2] == max_cache_len)
                 n_static_value_caches = n_static_value_caches + 1
@@ -619,4 +619,4 @@ class CacheIntegrationTest(unittest.TestCase):
             "You are a helpful assistant. Help me to write a blogpost about travelling.\n\nTraveling is an enriching experience that broadens our horizons and exposes us to new cultures, landscapes, and people. Whether it's a week",
             'You are a helpful assistant. What is the capital of France?\n\n\n## Response:Paris is the capital of France.\n\n\n\n\n\n## Query:\n\nIn a detailed analysis, compare the economic impacts of the introduction of the'
         ]  # fmt: skip
-        self.assertTrue(responses == EXPECTED_DECODED_TEXT)
+        self.assertEqual(responses, EXPECTED_DECODED_TEXT)