[test] update test_past_key_values_format (#37614)

allow custom shapes
Joao Gante 2025-04-22 11:07:34 +01:00 committed by GitHub
parent 1cd110c6cb
commit 362fa37da2
19 changed files with 134 additions and 166 deletions
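
The diff below makes the common `test_past_key_values_format` parameterizable: a model test whose cache has non-standard shapes can now pass its expected per-layer shapes through the new `custom_all_cache_shapes` argument instead of skipping the test. A minimal sketch of the data structure that argument expects, using made-up dimensions rather than values from this diff:

    # Hypothetical sizes, for illustration only.
    batch_size, seq_length = 2, 7
    num_layers, num_key_value_heads, head_dim = 4, 2, 8

    # One [key_cache_shape, value_cache_shape] entry per decoder layer.
    self_attention_shape = (batch_size, num_key_value_heads, seq_length, head_dim)
    custom_all_cache_shapes = [[self_attention_shape, self_attention_shape] for _ in range(num_layers)]

    # A model test would then forward this list to the common test, e.g.:
    # super().test_past_key_values_format(custom_all_cache_shapes=custom_all_cache_shapes)
    print(custom_all_cache_shapes[0])  # [(2, 2, 7, 8), (2, 2, 7, 8)]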


@@ -1539,92 +1539,133 @@ class GenerationTesterMixin:
         torch.testing.assert_close(next_logits_wo_padding, next_logits_with_padding, rtol=1e-5, atol=1e-5)
 
     @pytest.mark.generate
-    def test_past_key_values_format(self):
-        # Test that the KV cache is formatted correctly. Exceptions need to explicitly overwrite this test. Having a
-        # standard KV cache format is important for a consistent API (and for advanced generation methods).
+    def test_past_key_values_format(self, custom_all_cache_shapes=None):
+        """
+        Test that the KV cache is formatted correctly. Exceptions need to explicitly overwrite this test, or pass the
+        expected cache shapes.
+        Having a standard KV cache format is important for a consistent API (and for advanced generation methods).
+        """
         for model_class in self.all_generative_model_classes:
             config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
 
-            # If it doesn't support cache, pass the test
+            # 1. If it doesn't support cache, skip the test
             if not hasattr(config.get_text_config(), "use_cache"):
                 self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
 
             model = model_class(config).to(torch_device)
+            model = model.eval()
             if "use_cache" not in inputs:
                 inputs["use_cache"] = True
             outputs = model(**inputs)
 
-            # If "past_key_values" is not returned, pass the test (e.g. RWKV uses a different cache name and format)
             if "past_key_values" not in outputs:
                 self.skipTest(reason="This model doesn't return `past_key_values`")
 
+            # 2. retrieve the KV cache and compute its default expected shapes (if no custom shapes are provided)
+            past_kv = outputs["past_key_values"]
+            is_legacy_cache = not isinstance(past_kv, Cache)
+
             text_config = config.get_text_config()
-            num_hidden_layers = (
+            num_decoder_layers = (
                 getattr(text_config, "decoder_layers", None)
                 or getattr(text_config, "num_decoder_layers", None)
                 or text_config.num_hidden_layers
             )
-            num_attention_heads = getattr(text_config, "decoder_attention_heads", text_config.num_attention_heads)
-            embed_dim = getattr(text_config, "d_model", text_config.hidden_size)
-            per_head_embed_dim = embed_dim // num_attention_heads
-
-            # some models have different num-head for query vs key/value so we need to assign correct value
-            # BUT only after `per_head_embed_dim` is set
-            num_attention_heads = (
-                text_config.num_key_value_heads
-                if getattr(text_config, "num_key_value_heads", None) is not None
-                else num_attention_heads
-            )
-
-            past_kv = outputs["past_key_values"]
-            self.assertEqual(len(past_kv), num_hidden_layers)
-
-            # Encoder-Decoder checks
-            if config.is_encoder_decoder:
-                # encoder-decoder models usually don't have text config
-                # below is needed only for Pix2Struct which we cannot modify now due to BC
-                config = config.get_text_config()
-                encoder_num_attention_heads = (
-                    config.encoder_attention_heads
-                    if hasattr(config, "encoder_attention_heads")
-                    else config.num_attention_heads
-                )
-                encoder_per_head_embed_dim = embed_dim // encoder_num_attention_heads
-                batch_size, seq_length = inputs["decoder_input_ids"].shape
-                for i in range(num_hidden_layers):
-                    self.assertEqual(len(past_kv[i]), 4)  # K V for the decoder + K V for the encoder = 4
-                    self.assertEqual(
-                        past_kv[i][0].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim)
-                    )
-                    self.assertEqual(
-                        past_kv[i][1].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim)
-                    )
-                    # The sequence length for the encoder K V depends on the model. Since it is not manipulated in
-                    # autoregressive generation, I'm keeping the test general and not checking the 3rd dim
-                    self.assertEqual(
-                        (past_kv[i][2].shape[0], past_kv[i][2].shape[1], past_kv[i][2].shape[3]),
-                        (batch_size, encoder_num_attention_heads, encoder_per_head_embed_dim),
-                    )
-                    self.assertEqual(
-                        (past_kv[i][3].shape[0], past_kv[i][3].shape[1], past_kv[i][3].shape[3]),
-                        (batch_size, encoder_num_attention_heads, encoder_per_head_embed_dim),
-                    )
-
-            # Decoder-only checks
-            else:
-                # TODO: this line is only needed because of imagegpt, where "pixel_values" = "input_ids". Fix the
-                # tests in imagegpt such that `prepare_config_and_inputs_for_common` returns the later (and the other
-                # tests use it)
-                key = "input_ids" if "input_ids" in inputs else "pixel_values"
-                batch_size, seq_length = inputs[key].shape
-
-                for i in range(num_hidden_layers):
-                    self.assertEqual(len(past_kv[0]), 2)  # K V for the decoder = 2
-                    self.assertEqual(
-                        past_kv[i][0].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim)
-                    )
-                    self.assertEqual(
-                        past_kv[i][1].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim)
-                    )
+            if custom_all_cache_shapes is None:
+                num_query_attention_heads = getattr(
+                    text_config, "decoder_attention_heads", text_config.num_attention_heads
+                )
+                embed_dim = getattr(text_config, "d_model", text_config.hidden_size)
+                per_head_embed_dim = embed_dim // num_query_attention_heads
+                num_key_value_heads = (
+                    text_config.num_key_value_heads
+                    if getattr(text_config, "num_key_value_heads", None) is not None
+                    else num_query_attention_heads
+                )
+                if config.is_encoder_decoder:
+                    encoder_num_attention_heads = (
+                        text_config.encoder_attention_heads
+                        if hasattr(text_config, "encoder_attention_heads")
+                        else text_config.num_attention_heads
+                    )
+                    encoder_per_head_embed_dim = embed_dim // encoder_num_attention_heads
+                    batch_size, seq_length = inputs["decoder_input_ids"].shape
+                    # The sequence length for the encoder K V depends on the model. Since it is not manipulated in
+                    # autoregressive generation, we're keeping the test general and not checking the 3rd dim
+                    default_cross_attention_shape = (
+                        batch_size,
+                        encoder_num_attention_heads,
+                        encoder_per_head_embed_dim,
+                    )
+                    default_self_attention_shape = (batch_size, num_key_value_heads, seq_length, per_head_embed_dim)
+                    all_cache_shapes = [
+                        [
+                            default_self_attention_shape,
+                            default_self_attention_shape,
+                            default_cross_attention_shape,
+                            default_cross_attention_shape,
+                        ]
+                        for _ in range(num_decoder_layers)
+                    ]
+                else:
+                    batch_size, seq_length = inputs["input_ids"].shape
+                    default_self_attention_shape = (batch_size, num_key_value_heads, seq_length, per_head_embed_dim)
+                    all_cache_shapes = [
+                        [default_self_attention_shape, default_self_attention_shape] for _ in range(num_decoder_layers)
+                    ]
+            else:
+                all_cache_shapes = custom_all_cache_shapes
+
+            # 3. Check cache shapes
+            # 3.1. Encoder-Decoder checks
+            if config.is_encoder_decoder:
+                num_cache_decoder_layers = (
+                    len(past_kv) if is_legacy_cache else len(past_kv.self_attention_cache.key_cache)
+                )
+                self.assertEqual(num_cache_decoder_layers, num_decoder_layers)
+
+                for i in range(num_decoder_layers):
+                    if is_legacy_cache:
+                        self.assertEqual(len(past_kv[0]), 4)  # legacy check: confirm number of elements in tuple
+
+                    # Self attention
+                    self_attention_layer_key_cache = (
+                        past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.key_cache[i]
+                    )
+                    self_attention_layer_value_cache = (
+                        past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.value_cache[i]
+                    )
+                    self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0])
+                    self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1])
+
+                    # Cross attention (ignore 3rd dim, see default shape preparation)
+                    cross_attention_layer_key_cache = (
+                        past_kv[i][2] if is_legacy_cache else past_kv.cross_attention_cache.key_cache[i]
+                    )
+                    cross_attention_layer_value_cache = (
+                        past_kv[i][3] if is_legacy_cache else past_kv.cross_attention_cache.value_cache[i]
+                    )
+                    cross_attention_layer_key_cache = cross_attention_layer_key_cache[:, :, 0, :]
+                    cross_attention_layer_value_cache = cross_attention_layer_value_cache[:, :, 0, :]
+                    self.assertEqual(cross_attention_layer_key_cache.shape, all_cache_shapes[i][2])
+                    self.assertEqual(cross_attention_layer_value_cache.shape, all_cache_shapes[i][3])
+
+            # 3.2. Decoder-only checks
+            else:
+                num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv.key_cache)
+                self.assertEqual(num_cache_decoder_layers, num_decoder_layers)
+
+                for i in range(num_decoder_layers):
+                    if is_legacy_cache:
+                        self.assertEqual(len(past_kv[0]), 2)  # legacy check: confirm number of elements in tuple
+
+                    # Self attention
+                    self_attention_layer_key_cache = past_kv[i][0] if is_legacy_cache else past_kv.key_cache[i]
+                    self_attention_layer_value_cache = past_kv[i][1] if is_legacy_cache else past_kv.value_cache[i]
+                    self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0])
+                    self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1])
 
     @pytest.mark.generate
     @parameterized.expand([("greedy", 1), ("beam search", 2)])
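
Most of the per-model overrides removed in the hunks below existed only because the old default check assumed as many key/value heads as query heads. A small sketch of the new default shape computation above, with hypothetical GQA config values (not taken from this diff), showing why GQA models now pass without an override:

    # Hypothetical GQA config: 8 query heads, but only 2 key/value heads.
    batch_size, seq_length = 2, 7
    hidden_size, num_query_attention_heads, num_key_value_heads = 64, 8, 2

    # The per-head dim is still derived from the number of *query* heads...
    per_head_embed_dim = hidden_size // num_query_attention_heads  # 64 // 8 = 8

    # ...but the cache itself is indexed by the number of *key/value* heads.
    default_self_attention_shape = (batch_size, num_key_value_heads, seq_length, per_head_embed_dim)
    assert default_self_attention_shape == (2, 2, 7, 8)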


@@ -429,9 +429,23 @@ class DeepseekV3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
         with self.assertRaises(AssertionError):
             torch.testing.assert_close(yarn_sin_long, original_sin_long)
 
-    @unittest.skip(reason="Deepseek-V3 uses MLA on all models so the KV cache is a non standard format")
     def test_past_key_values_format(self):
-        pass
+        """
+        Overwriting to pass the expected cache shapes (Deepseek-V3 uses MLA so the cache shapes are non-standard)
+        """
+        config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
+        batch_size, seq_length = inputs["input_ids"].shape
+        # difference: last dim
+        k_embed_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
+        v_embed_dim = config.v_head_dim
+        self_attention_key_cache_shape = (batch_size, config.num_key_value_heads, seq_length, k_embed_dim)
+        self_attention_value_cache_shape = (batch_size, config.num_key_value_heads, seq_length, v_embed_dim)
+        # build the full cache shapes
+        num_hidden_layers = config.num_hidden_layers
+        all_cache_shapes = [
+            [self_attention_key_cache_shape, self_attention_value_cache_shape] for _ in range(num_hidden_layers)
+        ]
+        super().test_past_key_values_format(custom_all_cache_shapes=all_cache_shapes)
 
     @require_torch_sdpa
     @slow
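
For context, Deepseek-V3 needs custom shapes because MLA gives the key and value caches different last dimensions, which the single `per_head_embed_dim` of the default check cannot express. A quick arithmetic sketch with hypothetical head dimensions (not the model's real values):

    # Hypothetical MLA head dimensions, for illustration only.
    qk_nope_head_dim, qk_rope_head_dim, v_head_dim = 64, 32, 128

    k_embed_dim = qk_nope_head_dim + qk_rope_head_dim  # 96: last dim of the key cache
    v_embed_dim = v_head_dim  # 128: last dim of the value cache
    assert k_embed_dim != v_embed_dim  # key and value cache shapes differ within each layer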


@@ -264,51 +264,6 @@ class FalconModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
         result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
         self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
 
-    def test_past_key_values_format(self):
-        # Falcon can have different numbers of KV-heads than the number of query heads, so we need
-        # to override this test to use the right head counts.
-        for model_class in self.all_generative_model_classes:
-            config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
-
-            # If it doesn't support cache, pass the test
-            if not hasattr(config, "use_cache"):
-                self.skipTest(reason="Model does not support cache")
-
-            model = model_class(config).to(torch_device)
-            if "use_cache" not in inputs:
-                inputs["use_cache"] = True
-            outputs = model(**inputs)
-
-            # If "past_key_values" is not returned, pass the test (e.g. RWKV uses a different cache name and format)
-            if "past_key_values" not in outputs:
-                self.skipTest(reason="Model does not return past_key_values")
-
-            num_hidden_layers = (
-                getattr(config, "decoder_layers", None)
-                or getattr(config, "num_decoder_layers", None)
-                or config.num_hidden_layers
-            )
-            num_attention_heads = getattr(config, "num_kv_heads", config.num_attention_heads)
-            embed_dim = getattr(config, "d_model", config.hidden_size)
-            per_head_embed_dim = embed_dim // num_attention_heads
-
-            past_kv = outputs["past_key_values"]
-            self.assertEqual(len(past_kv), num_hidden_layers)
-
-            batch_size, seq_length = inputs["input_ids"].shape
-            for i in range(num_hidden_layers):
-                if config.new_decoder_architecture:
-                    num_attention_heads = config.num_attention_heads
-                elif config.multi_query:
-                    num_attention_heads = 1
-                self.assertEqual(len(past_kv[0]), 2)  # K V for the decoder = 2
-                self.assertEqual(
-                    past_kv[i][0].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim)
-                )
-                self.assertEqual(
-                    past_kv[i][1].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim)
-                )
-
     @parameterized.expand([("linear",), ("dynamic",)])
     # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling_from_config with Llama->Falcon
     def test_model_rope_scaling_from_config(self, scaling_type):


@@ -296,10 +296,6 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
             (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
         )
 
-    @unittest.skip(reason="Gemma uses GQA on all models so the KV cache is a non standard format")
-    def test_past_key_values_format(self):
-        pass
-
     @require_flash_attn
     @require_torch_gpu
     @pytest.mark.flash_attn_test


@@ -264,10 +264,6 @@ class GlmModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
             (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
         )
 
-    @unittest.skip(reason="Glm uses GQA on all models so the KV cache is a non standard format")
-    def test_past_key_values_format(self):
-        pass
-
     @is_flaky()
     def test_custom_4d_attention_mask(self):
         """Overwrite the common test to use atol=1e-3 instead of 1e-4. Can still rarely fail, thus flaky."""


@@ -222,12 +222,6 @@ class GotOcr2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
     def test_generate_from_inputs_embeds_with_static_cache(self):
         pass
 
-    @unittest.skip(
-        reason="GotOcr2's language backbone is Qwen2 which uses GQA so the KV cache is a non standard format"
-    )
-    def test_past_key_values_format(self):
-        pass
-
     @unittest.skip("FlashAttention only support fp16 and bf16 data type")
     def test_flash_attn_2_fp32_ln(self):
         pass


@@ -319,6 +319,10 @@ class ImageGPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
     def test_left_padding_compatibility(self):
         pass
 
+    @unittest.skip(reason="Model inputs don't fit test pattern")  # and it's not used enough to be worth fixing :)
+    def test_past_key_values_format(self):
+        pass
+
 
 # We will verify our results on an image of cute cats
 def prepare_img():


@@ -251,10 +251,6 @@ class JetMoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
         result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
         self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
 
-    @unittest.skip(reason="JetMoe uses MoA on all models so the KV cache is a non standard format")
-    def test_past_key_values_format(self):
-        pass
-
     @require_flash_attn
     @require_torch_gpu
     @pytest.mark.flash_attn_test


@@ -292,10 +292,6 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
             (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
         )
 
-    @unittest.skip(reason="Mistral uses GQA on all models so the KV cache is a non standard format")
-    def test_past_key_values_format(self):
-        pass
-
     @require_flash_attn
     @require_torch_gpu
     @pytest.mark.flash_attn_test


@@ -324,10 +324,6 @@ class TFMistralModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestC
         result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
         self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
 
-    @unittest.skip("Mistral uses GQA on all models so the KV cache is a non standard format")
-    def test_past_key_values_format(self):
-        pass
-
     @unittest.skip("Vocab resizing is not supported")
     def test_save_load_after_resize_token_embeddings(self):
         pass


@@ -291,10 +291,6 @@ class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
             (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
         )
 
-    @unittest.skip(reason="Mixtral uses GQA on all models so the KV cache is a non standard format")
-    def test_past_key_values_format(self):
-        pass
-
     @require_flash_attn
     @require_torch_gpu
     @pytest.mark.flash_attn_test


@@ -409,7 +409,7 @@ class MllamaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTester
         pass
 
     @pytest.mark.generate
-    # overridden because mllama has special cache for self and cross attentions
+    # overridden because mllama is not an encoder-decoder model, but has encoder-decoder-like cache
     def test_past_key_values_format(self):
         # Test that the KV cache is formatted correctly. Exceptions need to explicitly overwrite this test. Having a
         # standard KV cache format is important for a consistent API (and for advanced generation methods).


@@ -303,10 +303,6 @@ class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
             (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
         )
 
-    @unittest.skip(reason="Qwen2 uses GQA on all models so the KV cache is a non standard format")
-    def test_past_key_values_format(self):
-        pass
-
     @require_flash_attn
     @require_torch_gpu
     @pytest.mark.flash_attn_test


@@ -331,10 +331,6 @@ class Qwen2MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
             (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
        )
 
-    @unittest.skip(reason="Qwen2Moe uses GQA on all models so the KV cache is a non standard format")
-    def test_past_key_values_format(self):
-        pass
-
     @require_flash_attn
     @require_torch_gpu
     @pytest.mark.flash_attn_test


@@ -306,10 +306,6 @@ class Qwen3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
             (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
         )
 
-    # Ignore copy
-    def test_past_key_values_format(self):
-        super().test_past_key_values_format()
-
     @require_flash_attn
     @require_torch_gpu
     @pytest.mark.flash_attn_test


@@ -325,10 +325,6 @@ class Qwen3MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
             (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
         )
 
-    # Ignore copy
-    def test_past_key_values_format(self):
-        super().test_past_key_values_format()
-
     @require_flash_attn
     @require_torch_gpu
     @pytest.mark.flash_attn_test


@@ -223,10 +223,6 @@ class RecurrentGemmaModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Te
         config_and_inputs[0].position_embedding_type = type
         self.model_tester.create_and_check_model(*config_and_inputs)
 
-    @unittest.skip(reason="RecurrentGemma does not return pkv")
-    def test_past_key_values_format(self):
-        pass
-
     @unittest.skip(reason="RecurrentGemma only supports sdpa")
     def test_eager_matches_sdpa_generate(self):
         pass


@@ -281,10 +281,6 @@ class Starcoder2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
             (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
         )
 
-    @unittest.skip(reason="Starcoder2 uses GQA on all models so the KV cache is a non standard format")
-    def test_past_key_values_format(self):
-        pass
-
     @require_flash_attn
     @require_torch_gpu
     @pytest.mark.flash_attn_test


@@ -322,14 +322,22 @@ class Zamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
     def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
         pass
 
-    @unittest.skip("Zamba2 has a hybrid cache")
     def test_past_key_values_format(self):
-        r"""
-        Zamba2's cache shape depends on whether a given layer is mamba or attention.
-        For mamba layers, the KV cache is empty and has shape [batch_size, 0].
-        The shape checks of this test assume instead that every layer has an attention cache, so we skip it.
+        """
+        Overwriting to pass the expected cache shapes (Zamba2 has cache shape = [batch_size, 0] for mamba layers)
         """
-        pass
+        config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
+        batch_size, seq_length = inputs["input_ids"].shape
+        per_head_embed_dim = config.attention_head_dim  # note: this one is not a common attribute name
+        self_attention_cache_shape = (batch_size, config.num_key_value_heads, seq_length, per_head_embed_dim)
+        # build the full cache shapes, including mamba layers
+        all_cache_shapes = []
+        for i in range(config.num_hidden_layers):
+            if config.layers_block_type[i] == "mamba":
+                all_cache_shapes.append([torch.Size([batch_size, 0]), torch.Size([batch_size, 0])])
+            else:
+                all_cache_shapes.append([self_attention_cache_shape, self_attention_cache_shape])
+        super().test_past_key_values_format(custom_all_cache_shapes=all_cache_shapes)
 
     @unittest.skip(reason="Zamba2 has hybrid cache.")
     def test_generate_continue_from_inputs_embeds(self):