mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
fix static cache data type miss-match (#34799)
* fix gptj data type missmatch Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * add low precision static cache tests Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix format Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix low-precision static cache tests * fix format Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * avoid config change Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * change data type convert in cache copy Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix comment Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * cast key value after k v out Signed-off-by: jiqing-feng <jiqing.feng@intel.com> --------- Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
This commit is contained in:
parent
b13916c09d
commit
a464afbe2a
@ -1217,6 +1217,8 @@ class StaticCache(Cache):
|
||||
|
||||
k_out = self.key_cache[layer_idx]
|
||||
v_out = self.value_cache[layer_idx]
|
||||
key_states = key_states.to(k_out.dtype)
|
||||
value_states = value_states.to(v_out.dtype)
|
||||
|
||||
if cache_position is None:
|
||||
k_out.copy_(key_states)
|
||||
|
@ -1901,36 +1901,41 @@ class GenerationTesterMixin:
|
||||
seq_length = main_input.shape[-1]
|
||||
max_new_tokens = 20
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
generation_kwargs = {
|
||||
"max_new_tokens": max_new_tokens,
|
||||
"return_dict_in_generate": True, # Required to return `past_key_values`
|
||||
"output_scores": True,
|
||||
"use_cache": True,
|
||||
}
|
||||
for dtype in (torch.float32, torch.float16):
|
||||
model = model_class(config).to(torch_device).to(dtype).eval()
|
||||
generation_kwargs = {
|
||||
"max_new_tokens": max_new_tokens,
|
||||
"return_dict_in_generate": True, # Required to return `past_key_values`
|
||||
"output_scores": True,
|
||||
"use_cache": True,
|
||||
}
|
||||
|
||||
static_cache_generation = model.generate(**generation_kwargs, **inputs_dict, cache_implementation="static")
|
||||
static_cache_generation = model.generate(
|
||||
**generation_kwargs, **inputs_dict, cache_implementation="static"
|
||||
)
|
||||
|
||||
# Check 1: The cache shapes must match the expected shapes
|
||||
max_cache_len = seq_length + max_new_tokens
|
||||
config = config.text_config if hasattr(config, "text_config") else config
|
||||
head_dim = (
|
||||
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
|
||||
)
|
||||
num_key_value_heads = (
|
||||
config.num_attention_heads
|
||||
if getattr(config, "num_key_value_heads", None) is None
|
||||
else config.num_key_value_heads
|
||||
)
|
||||
num_hidden_layers = config.num_hidden_layers
|
||||
cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim)
|
||||
self.assertTrue(isinstance(static_cache_generation.past_key_values, StaticCache))
|
||||
self.assertTrue(len(static_cache_generation.past_key_values.key_cache) == num_hidden_layers)
|
||||
self.assertTrue(static_cache_generation.past_key_values.key_cache[0].shape == cache_shape)
|
||||
# Check 1: The cache shapes must match the expected shapes
|
||||
max_cache_len = seq_length + max_new_tokens
|
||||
text_config = config.text_config if hasattr(config, "text_config") else config
|
||||
head_dim = (
|
||||
text_config.head_dim
|
||||
if hasattr(text_config, "head_dim")
|
||||
else text_config.hidden_size // text_config.num_attention_heads
|
||||
)
|
||||
num_key_value_heads = (
|
||||
text_config.num_attention_heads
|
||||
if getattr(text_config, "num_key_value_heads", None) is None
|
||||
else text_config.num_key_value_heads
|
||||
)
|
||||
num_hidden_layers = text_config.num_hidden_layers
|
||||
cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim)
|
||||
self.assertTrue(isinstance(static_cache_generation.past_key_values, StaticCache))
|
||||
self.assertTrue(len(static_cache_generation.past_key_values.key_cache) == num_hidden_layers)
|
||||
self.assertTrue(static_cache_generation.past_key_values.key_cache[0].shape == cache_shape)
|
||||
|
||||
# Check 2: The outputs must be similar to the case with dynamic cache
|
||||
dynamic_cache_generation = model.generate(**generation_kwargs, **inputs_dict)
|
||||
self._check_similar_generate_outputs(dynamic_cache_generation, static_cache_generation)
|
||||
# Check 2: The outputs must be similar to the case with dynamic cache
|
||||
dynamic_cache_generation = model.generate(**generation_kwargs, **inputs_dict)
|
||||
self._check_similar_generate_outputs(dynamic_cache_generation, static_cache_generation)
|
||||
|
||||
@require_optimum_quanto
|
||||
@pytest.mark.generate
|
||||
|
Loading…
Reference in New Issue
Block a user