fix static cache data type mismatch (#34799)

* fix gptj data type mismatch

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* add low precision static cache tests

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix format

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix low-precision static cache tests

* fix format

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* avoid config change

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* change data type conversion in cache copy

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix comment

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* cast key/value states after k_out/v_out

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

---------

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
jiqing-feng 2024-11-25 23:59:38 +08:00 committed by GitHub
parent b13916c09d
commit a464afbe2a
2 changed files with 34 additions and 27 deletions


@@ -1217,6 +1217,8 @@ class StaticCache(Cache):
         k_out = self.key_cache[layer_idx]
         v_out = self.value_cache[layer_idx]
+        key_states = key_states.to(k_out.dtype)
+        value_states = value_states.to(v_out.dtype)
         if cache_position is None:
             k_out.copy_(key_states)

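The two added casts matter when the key/value states coming from the model are in a different dtype than the pre-allocated static cache buffers, for example a float16 model writing into float32 buffers or the reverse. Below is a minimal sketch of the failure mode and the fix, with made-up shapes, assuming the indexed write path of StaticCache.update uses Tensor.index_copy_ (that branch is not shown in the hunk above):

    import torch

    # Pre-allocated cache buffer: (batch, num_kv_heads, max_cache_len, head_dim) in fp16
    k_out = torch.zeros(1, 4, 32, 8, dtype=torch.float16)
    # New key states produced by the model in fp32 (e.g. after fp32 rotary embeddings)
    key_states = torch.randn(1, 4, 1, 8, dtype=torch.float32)
    cache_position = torch.tensor([5])

    # Without the cast, the indexed write raises a RuntimeError because
    # index_copy_ does not convert dtypes implicitly:
    #     k_out.index_copy_(2, cache_position, key_states)

    # Casting the incoming states to the buffer dtype first, as the patch does,
    # makes the write succeed:
    key_states = key_states.to(k_out.dtype)
    k_out.index_copy_(2, cache_position, key_states)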

@@ -1901,36 +1901,41 @@ class GenerationTesterMixin:
         seq_length = main_input.shape[-1]
         max_new_tokens = 20

-        model = model_class(config).to(torch_device).eval()
-        generation_kwargs = {
-            "max_new_tokens": max_new_tokens,
-            "return_dict_in_generate": True,  # Required to return `past_key_values`
-            "output_scores": True,
-            "use_cache": True,
-        }
+        for dtype in (torch.float32, torch.float16):
+            model = model_class(config).to(torch_device).to(dtype).eval()
+            generation_kwargs = {
+                "max_new_tokens": max_new_tokens,
+                "return_dict_in_generate": True,  # Required to return `past_key_values`
+                "output_scores": True,
+                "use_cache": True,
+            }

-        static_cache_generation = model.generate(**generation_kwargs, **inputs_dict, cache_implementation="static")
+            static_cache_generation = model.generate(
+                **generation_kwargs, **inputs_dict, cache_implementation="static"
+            )

-        # Check 1: The cache shapes must match the expected shapes
-        max_cache_len = seq_length + max_new_tokens
-        config = config.text_config if hasattr(config, "text_config") else config
-        head_dim = (
-            config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
-        )
-        num_key_value_heads = (
-            config.num_attention_heads
-            if getattr(config, "num_key_value_heads", None) is None
-            else config.num_key_value_heads
-        )
-        num_hidden_layers = config.num_hidden_layers
-        cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim)
-        self.assertTrue(isinstance(static_cache_generation.past_key_values, StaticCache))
-        self.assertTrue(len(static_cache_generation.past_key_values.key_cache) == num_hidden_layers)
-        self.assertTrue(static_cache_generation.past_key_values.key_cache[0].shape == cache_shape)
+            # Check 1: The cache shapes must match the expected shapes
+            max_cache_len = seq_length + max_new_tokens
+            text_config = config.text_config if hasattr(config, "text_config") else config
+            head_dim = (
+                text_config.head_dim
+                if hasattr(text_config, "head_dim")
+                else text_config.hidden_size // text_config.num_attention_heads
+            )
+            num_key_value_heads = (
+                text_config.num_attention_heads
+                if getattr(text_config, "num_key_value_heads", None) is None
+                else text_config.num_key_value_heads
+            )
+            num_hidden_layers = text_config.num_hidden_layers
+            cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim)
+            self.assertTrue(isinstance(static_cache_generation.past_key_values, StaticCache))
+            self.assertTrue(len(static_cache_generation.past_key_values.key_cache) == num_hidden_layers)
+            self.assertTrue(static_cache_generation.past_key_values.key_cache[0].shape == cache_shape)

-        # Check 2: The outputs must be similar to the case with dynamic cache
-        dynamic_cache_generation = model.generate(**generation_kwargs, **inputs_dict)
-        self._check_similar_generate_outputs(dynamic_cache_generation, static_cache_generation)
+            # Check 2: The outputs must be similar to the case with dynamic cache
+            dynamic_cache_generation = model.generate(**generation_kwargs, **inputs_dict)
+            self._check_similar_generate_outputs(dynamic_cache_generation, static_cache_generation)

     @require_optimum_quanto
     @pytest.mark.generate
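
The updated test loops over float32 and float16 so the static cache path is also exercised in low precision. Outside the test suite, the same scenario can be reproduced through the public generate API; a minimal sketch, where the checkpoint name and prompt are only examples:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Example checkpoint; any decoder-only model with static cache support works
    model_id = "EleutherAI/gpt-j-6b"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).eval()

    inputs = tokenizer("Static caches should also work in half precision", return_tensors="pt")
    # cache_implementation="static" pre-allocates the key/value buffers; before this fix,
    # a dtype mismatch between the model's key/value states and those buffers could error out.
    output = model.generate(**inputs, max_new_tokens=20, cache_implementation="static")
    print(tokenizer.decode(output[0], skip_special_tokens=True))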