Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-03 21:00:08 +06:00)
commit 362fa37da2 (parent 1cd110c6cb)
@@ -1539,92 +1539,133 @@ class GenerationTesterMixin:
             torch.testing.assert_close(next_logits_wo_padding, next_logits_with_padding, rtol=1e-5, atol=1e-5)

     @pytest.mark.generate
-    def test_past_key_values_format(self):
-        # Test that the KV cache is formatted correctly. Exceptions need to explicitly overwrite this test. Having a
-        # standard KV cache format is important for a consistent API (and for advanced generation methods).
+    def test_past_key_values_format(self, custom_all_cache_shapes=None):
+        """
+        Test that the KV cache is formatted correctly. Exceptions need to explicitly overwrite this test, or pass the
+        expected cache shapes.
+        Having a standard KV cache format is important for a consistent API (and for advanced generation methods).
+        """
         for model_class in self.all_generative_model_classes:
             config, inputs = self.model_tester.prepare_config_and_inputs_for_common()

-            # If it doesn't support cache, pass the test
+            # 1. If it doesn't support cache, skip the test
             if not hasattr(config.get_text_config(), "use_cache"):
                 self.skipTest(reason=f"{model_class.__name__} doesn't support caching")

             model = model_class(config).to(torch_device)
+            model = model.eval()
             if "use_cache" not in inputs:
                 inputs["use_cache"] = True
             outputs = model(**inputs)

-            # If "past_key_values" is not returned, pass the test (e.g. RWKV uses a different cache name and format)
             if "past_key_values" not in outputs:
                 self.skipTest(reason="This model doesn't return `past_key_values`")

+            # 2. retrieve the KV cache and compute its default expected shapes (if no custom shapes are provided)
+            past_kv = outputs["past_key_values"]
+            is_legacy_cache = not isinstance(past_kv, Cache)
+
             text_config = config.get_text_config()
-            num_hidden_layers = (
+            num_decoder_layers = (
                 getattr(text_config, "decoder_layers", None)
                 or getattr(text_config, "num_decoder_layers", None)
                 or text_config.num_hidden_layers
             )
-            num_attention_heads = getattr(text_config, "decoder_attention_heads", text_config.num_attention_heads)
-            embed_dim = getattr(text_config, "d_model", text_config.hidden_size)
-            per_head_embed_dim = embed_dim // num_attention_heads

-            # some models have different num-head for query vs key/value so we need to assign correct value
-            # BUT only after `per_head_embed_dim` is set
-            num_attention_heads = (
-                text_config.num_key_value_heads
-                if getattr(text_config, "num_key_value_heads", None) is not None
-                else num_attention_heads
-            )
-
-            past_kv = outputs["past_key_values"]
-            self.assertEqual(len(past_kv), num_hidden_layers)
-
-            # Encoder-Decoder checks
-            if config.is_encoder_decoder:
-                # encoder-decoder models usually don't have text config
-                # below is needed only for Pix2Struct which we cannot modify now due to BC
-                config = config.get_text_config()
-                encoder_num_attention_heads = (
-                    config.encoder_attention_heads
-                    if hasattr(config, "encoder_attention_heads")
-                    else config.num_attention_heads
-                )
-                encoder_per_head_embed_dim = embed_dim // encoder_num_attention_heads
-                batch_size, seq_length = inputs["decoder_input_ids"].shape
-                for i in range(num_hidden_layers):
-                    self.assertEqual(len(past_kv[i]), 4)  # K V for the decoder + K V for the encoder = 4
-                    self.assertEqual(
-                        past_kv[i][0].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim)
-                    )
-                    self.assertEqual(
-                        past_kv[i][1].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim)
-                    )
-                    # The sequence length for the encoder K V depends on the model. Since it is not manipulated in
-                    # autoregressive generation, I'm keeping the test general and not checking the 3rd dim
-                    self.assertEqual(
-                        (past_kv[i][2].shape[0], past_kv[i][2].shape[1], past_kv[i][2].shape[3]),
-                        (batch_size, encoder_num_attention_heads, encoder_per_head_embed_dim),
-                    )
-                    self.assertEqual(
-                        (past_kv[i][3].shape[0], past_kv[i][3].shape[1], past_kv[i][3].shape[3]),
-                        (batch_size, encoder_num_attention_heads, encoder_per_head_embed_dim),
-                    )
-
-            # Decoder-only checks
-            else:
-                # TODO: this line is only needed because of imagegpt, where "pixel_values" = "input_ids". Fix the
-                # tests in imagegpt such that `prepare_config_and_inputs_for_common` returns the later (and the other
-                # tests use it)
-                key = "input_ids" if "input_ids" in inputs else "pixel_values"
-                batch_size, seq_length = inputs[key].shape
-                for i in range(num_hidden_layers):
-                    self.assertEqual(len(past_kv[0]), 2)  # K V for the decoder = 2
-                    self.assertEqual(
-                        past_kv[i][0].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim)
-                    )
-                    self.assertEqual(
-                        past_kv[i][1].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim)
-                    )
+            if custom_all_cache_shapes is None:
+                num_query_attention_heads = getattr(
+                    text_config, "decoder_attention_heads", text_config.num_attention_heads
+                )
+                embed_dim = getattr(text_config, "d_model", text_config.hidden_size)
+                per_head_embed_dim = embed_dim // num_query_attention_heads
+                num_key_value_heads = (
+                    text_config.num_key_value_heads
+                    if getattr(text_config, "num_key_value_heads", None) is not None
+                    else num_query_attention_heads
+                )
+                if config.is_encoder_decoder:
+                    encoder_num_attention_heads = (
+                        text_config.encoder_attention_heads
+                        if hasattr(text_config, "encoder_attention_heads")
+                        else text_config.num_attention_heads
+                    )
+                    encoder_per_head_embed_dim = embed_dim // encoder_num_attention_heads
+                    batch_size, seq_length = inputs["decoder_input_ids"].shape
+                    # The sequence length for the encoder K V depends on the model. Since it is not manipulated in
+                    # autoregressive generation, we're keeping the test general and not checking the 3rd dim
+                    default_cross_attention_shape = (
+                        batch_size,
+                        encoder_num_attention_heads,
+                        encoder_per_head_embed_dim,
+                    )
+                    default_self_attention_shape = (batch_size, num_key_value_heads, seq_length, per_head_embed_dim)
+                    all_cache_shapes = [
+                        [
+                            default_self_attention_shape,
+                            default_self_attention_shape,
+                            default_cross_attention_shape,
+                            default_cross_attention_shape,
+                        ]
+                        for _ in range(num_decoder_layers)
+                    ]
+                else:
+                    batch_size, seq_length = inputs["input_ids"].shape
+                    default_self_attention_shape = (batch_size, num_key_value_heads, seq_length, per_head_embed_dim)
+                    all_cache_shapes = [
+                        [default_self_attention_shape, default_self_attention_shape] for _ in range(num_decoder_layers)
+                    ]
+            else:
+                all_cache_shapes = custom_all_cache_shapes
+
+            # 3. Check cache shapes
+            # 3.1. Encoder-Decoder checks
+            if config.is_encoder_decoder:
+                num_cache_decoder_layers = (
+                    len(past_kv) if is_legacy_cache else len(past_kv.self_attention_cache.key_cache)
+                )
+                self.assertEqual(num_cache_decoder_layers, num_decoder_layers)
+
+                for i in range(num_decoder_layers):
+                    if is_legacy_cache:
+                        self.assertEqual(len(past_kv[0]), 4)  # legacy check: confirm number of elements in tuple
+
+                    # Self attention
+                    self_attention_layer_key_cache = (
+                        past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.key_cache[i]
+                    )
+                    self_attention_layer_value_cache = (
+                        past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.value_cache[i]
+                    )
+                    self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0])
+                    self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1])
+
+                    # Cross attention (ignore 3rd dim, see default shape preparation)
+                    cross_attention_layer_key_cache = (
+                        past_kv[i][2] if is_legacy_cache else past_kv.cross_attention_cache.key_cache[i]
+                    )
+                    cross_attention_layer_value_cache = (
+                        past_kv[i][3] if is_legacy_cache else past_kv.cross_attention_cache.value_cache[i]
+                    )
+                    cross_attention_layer_key_cache = cross_attention_layer_key_cache[:, :, 0, :]
+                    cross_attention_layer_value_cache = cross_attention_layer_value_cache[:, :, 0, :]
+                    self.assertEqual(cross_attention_layer_key_cache.shape, all_cache_shapes[i][2])
+                    self.assertEqual(cross_attention_layer_value_cache.shape, all_cache_shapes[i][3])
+
+            # 3.2. Decoder-only checks
+            else:
+                num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv.key_cache)
+                self.assertEqual(num_cache_decoder_layers, num_decoder_layers)
+
+                for i in range(num_decoder_layers):
+                    if is_legacy_cache:
+                        self.assertEqual(len(past_kv[0]), 2)  # legacy check: confirm number of elements in tuple
+
+                    # Self attention
+                    self_attention_layer_key_cache = past_kv[i][0] if is_legacy_cache else past_kv.key_cache[i]
+                    self_attention_layer_value_cache = past_kv[i][1] if is_legacy_cache else past_kv.value_cache[i]
+                    self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0])
+                    self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1])

     @pytest.mark.generate
     @parameterized.expand([("greedy", 1), ("beam search", 2)])

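To make the new `custom_all_cache_shapes` hook concrete, here is a minimal sketch (not part of the commit) of what the default expectation reduces to for a decoder-only model; the dimensions below are assumed toy values chosen only for illustration:

# Minimal sketch of the default decoder-only expectation built by the common test.
# All dimensions here are assumed toy values, not taken from any real model config.
batch_size, seq_length = 2, 7
num_key_value_heads, per_head_embed_dim, num_decoder_layers = 4, 8, 3

default_self_attention_shape = (batch_size, num_key_value_heads, seq_length, per_head_embed_dim)
all_cache_shapes = [
    [default_self_attention_shape, default_self_attention_shape]  # one [key, value] pair per layer
    for _ in range(num_decoder_layers)
]
assert len(all_cache_shapes) == num_decoder_layers
assert all_cache_shapes[0][0] == (2, 4, 7, 8)  # key cache shape of layer 0

A model test that needs different shapes builds a list with this same per-layer structure and forwards it via `super().test_past_key_values_format(custom_all_cache_shapes=...)`, as the Deepseek-V3 and Zamba2 overrides below do.
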
@@ -429,9 +429,23 @@ class DeepseekV3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
         with self.assertRaises(AssertionError):
             torch.testing.assert_close(yarn_sin_long, original_sin_long)

-    @unittest.skip(reason="Deepseek-V3 uses MLA on all models so the KV cache is a non standard format")
     def test_past_key_values_format(self):
-        pass
+        """
+        Overwriting to pass the expected cache shapes (Deepseek-V3 uses MLA so the cache shapes are non-standard)
+        """
+        config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
+        batch_size, seq_length = inputs["input_ids"].shape
+        # difference: last dim
+        k_embed_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
+        v_embed_dim = config.v_head_dim
+        self_attention_key_cache_shape = (batch_size, config.num_key_value_heads, seq_length, k_embed_dim)
+        self_attention_value_cache_shape = (batch_size, config.num_key_value_heads, seq_length, v_embed_dim)
+        # build the full cache shapes
+        num_hidden_layers = config.num_hidden_layers
+        all_cache_shapes = [
+            [self_attention_key_cache_shape, self_attention_value_cache_shape] for _ in range(num_hidden_layers)
+        ]
+        super().test_past_key_values_format(custom_all_cache_shapes=all_cache_shapes)

     @require_torch_sdpa
     @slow

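As a rough illustration of the "difference: last dim" comment above: the key cache's last dimension combines the non-rotary and rotary head dimensions, while the value cache's is just the value head dimension. The numbers below are assumptions for illustration only, not taken from any specific Deepseek-V3 checkpoint:

# Assumed example head dims (illustrative only, not from a specific checkpoint).
qk_nope_head_dim, qk_rope_head_dim, v_head_dim = 128, 64, 128

k_embed_dim = qk_nope_head_dim + qk_rope_head_dim  # last dim of the key cache
v_embed_dim = v_head_dim                           # last dim of the value cache
assert (k_embed_dim, v_embed_dim) == (192, 128)    # key and value last dims differ
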
@@ -264,51 +264,6 @@ class FalconModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
         result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
         self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

-    def test_past_key_values_format(self):
-        # Falcon can have different numbers of KV-heads than the number of query heads, so we need
-        # to override this test to use the right head counts.
-        for model_class in self.all_generative_model_classes:
-            config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
-
-            # If it doesn't support cache, pass the test
-            if not hasattr(config, "use_cache"):
-                self.skipTest(reason="Model does not support cache")
-
-            model = model_class(config).to(torch_device)
-            if "use_cache" not in inputs:
-                inputs["use_cache"] = True
-            outputs = model(**inputs)
-
-            # If "past_key_values" is not returned, pass the test (e.g. RWKV uses a different cache name and format)
-            if "past_key_values" not in outputs:
-                self.skipTest(reason="Model does not return past_key_values")
-
-            num_hidden_layers = (
-                getattr(config, "decoder_layers", None)
-                or getattr(config, "num_decoder_layers", None)
-                or config.num_hidden_layers
-            )
-            num_attention_heads = getattr(config, "num_kv_heads", config.num_attention_heads)
-            embed_dim = getattr(config, "d_model", config.hidden_size)
-            per_head_embed_dim = embed_dim // num_attention_heads
-
-            past_kv = outputs["past_key_values"]
-            self.assertEqual(len(past_kv), num_hidden_layers)
-
-            batch_size, seq_length = inputs["input_ids"].shape
-            for i in range(num_hidden_layers):
-                if config.new_decoder_architecture:
-                    num_attention_heads = config.num_attention_heads
-                elif config.multi_query:
-                    num_attention_heads = 1
-                self.assertEqual(len(past_kv[0]), 2)  # K V for the decoder = 2
-                self.assertEqual(
-                    past_kv[i][0].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim)
-                )
-                self.assertEqual(
-                    past_kv[i][1].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim)
-                )
-
     @parameterized.expand([("linear",), ("dynamic",)])
     # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling_from_config with Llama->Falcon
     def test_model_rope_scaling_from_config(self, scaling_type):

@@ -296,10 +296,6 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
             (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
         )

-    @unittest.skip(reason="Gemma uses GQA on all models so the KV cache is a non standard format")
-    def test_past_key_values_format(self):
-        pass
-
     @require_flash_attn
     @require_torch_gpu
     @pytest.mark.flash_attn_test

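The skip removed here, like the analogous skips removed below for other models whose key/value-head count differs from their query-head count, lets those models run the rewritten common test: its default expectation derives the per-head dimension from the query heads but uses `num_key_value_heads` as the head count in the cache shape. A small sketch with assumed toy values, just to show the arithmetic:

# Sketch of the grouped-query attention case with assumed toy values (illustration only).
hidden_size, num_attention_heads, num_key_value_heads = 64, 8, 2
batch_size, seq_length = 1, 5

per_head_embed_dim = hidden_size // num_attention_heads  # 8, derived from the query heads
# The expected KV cache shape uses the (smaller) number of key/value heads:
expected_kv_shape = (batch_size, num_key_value_heads, seq_length, per_head_embed_dim)
assert expected_kv_shape == (1, 2, 5, 8)
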
@@ -264,10 +264,6 @@ class GlmModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
             (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
         )

-    @unittest.skip(reason="Glm uses GQA on all models so the KV cache is a non standard format")
-    def test_past_key_values_format(self):
-        pass
-
     @is_flaky()
     def test_custom_4d_attention_mask(self):
         """Overwrite the common test to use atol=1e-3 instead of 1e-4. Can still rarely fail, thus flaky."""

@@ -222,12 +222,6 @@ class GotOcr2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
     def test_generate_from_inputs_embeds_with_static_cache(self):
         pass

-    @unittest.skip(
-        reason="GotOcr2's language backbone is Qwen2 which uses GQA so the KV cache is a non standard format"
-    )
-    def test_past_key_values_format(self):
-        pass
-
     @unittest.skip("FlashAttention only support fp16 and bf16 data type")
     def test_flash_attn_2_fp32_ln(self):
         pass

@@ -319,6 +319,10 @@ class ImageGPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
     def test_left_padding_compatibility(self):
         pass

+    @unittest.skip(reason="Model inputs don't fit test pattern")  # and it's not used enough to be worth fixing :)
+    def test_past_key_values_format(self):
+        pass
+

 # We will verify our results on an image of cute cats
 def prepare_img():

@@ -251,10 +251,6 @@ class JetMoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
         result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
         self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

-    @unittest.skip(reason="JetMoe uses MoA on all models so the KV cache is a non standard format")
-    def test_past_key_values_format(self):
-        pass
-
     @require_flash_attn
     @require_torch_gpu
     @pytest.mark.flash_attn_test

@@ -292,10 +292,6 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
             (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
         )

-    @unittest.skip(reason="Mistral uses GQA on all models so the KV cache is a non standard format")
-    def test_past_key_values_format(self):
-        pass
-
     @require_flash_attn
     @require_torch_gpu
     @pytest.mark.flash_attn_test

@@ -324,10 +324,6 @@ class TFMistralModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestC
         result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
         self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

-    @unittest.skip("Mistral uses GQA on all models so the KV cache is a non standard format")
-    def test_past_key_values_format(self):
-        pass
-
     @unittest.skip("Vocab resizing is not supported")
     def test_save_load_after_resize_token_embeddings(self):
         pass

@@ -291,10 +291,6 @@ class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
             (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
         )

-    @unittest.skip(reason="Mixtral uses GQA on all models so the KV cache is a non standard format")
-    def test_past_key_values_format(self):
-        pass
-
     @require_flash_attn
     @require_torch_gpu
     @pytest.mark.flash_attn_test

@@ -409,7 +409,7 @@ class MllamaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTester
         pass

     @pytest.mark.generate
-    # overridden because mllama has special cache for self and cross attentions
+    # overridden because mllama is not an encoder-decoder model, but has encoder-decoder-like cache
     def test_past_key_values_format(self):
         # Test that the KV cache is formatted correctly. Exceptions need to explicitly overwrite this test. Having a
         # standard KV cache format is important for a consistent API (and for advanced generation methods).

@@ -303,10 +303,6 @@ class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
             (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
         )

-    @unittest.skip(reason="Qwen2 uses GQA on all models so the KV cache is a non standard format")
-    def test_past_key_values_format(self):
-        pass
-
     @require_flash_attn
     @require_torch_gpu
     @pytest.mark.flash_attn_test

@@ -331,10 +331,6 @@ class Qwen2MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
             (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
         )

-    @unittest.skip(reason="Qwen2Moe uses GQA on all models so the KV cache is a non standard format")
-    def test_past_key_values_format(self):
-        pass
-
     @require_flash_attn
     @require_torch_gpu
     @pytest.mark.flash_attn_test

@@ -306,10 +306,6 @@ class Qwen3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
             (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
         )

-    # Ignore copy
-    def test_past_key_values_format(self):
-        super().test_past_key_values_format()
-
     @require_flash_attn
     @require_torch_gpu
     @pytest.mark.flash_attn_test

@@ -325,10 +325,6 @@ class Qwen3MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
             (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
         )

-    # Ignore copy
-    def test_past_key_values_format(self):
-        super().test_past_key_values_format()
-
     @require_flash_attn
     @require_torch_gpu
     @pytest.mark.flash_attn_test

@@ -223,10 +223,6 @@ class RecurrentGemmaModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Te
         config_and_inputs[0].position_embedding_type = type
         self.model_tester.create_and_check_model(*config_and_inputs)

-    @unittest.skip(reason="RecurrentGemma does not return pkv")
-    def test_past_key_values_format(self):
-        pass
-
     @unittest.skip(reason="RecurrentGemma only supports sdpa")
     def test_eager_matches_sdpa_generate(self):
         pass

@@ -281,10 +281,6 @@ class Starcoder2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
             (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
         )

-    @unittest.skip(reason="Starcoder2 uses GQA on all models so the KV cache is a non standard format")
-    def test_past_key_values_format(self):
-        pass
-
     @require_flash_attn
     @require_torch_gpu
     @pytest.mark.flash_attn_test

@@ -322,14 +322,22 @@ class Zamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
     def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
         pass

-    @unittest.skip("Zamba2 has a hybrid cache")
     def test_past_key_values_format(self):
-        r"""
-        Zamba2's cache shape depends on whether a given layer is mamba or attention.
-        For mamba layers, the KV cache has shape is empty and has shape [batch_size, 0].
-        The shape checks of this test assume instead that every layer has an attention cache, so we skip it.
+        """
+        Overwriting to pass the expected cache shapes (Zamba2 has cache shape = [batch_size, 0] for mamba layers)
         """
-        pass
+        config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
+        batch_size, seq_length = inputs["input_ids"].shape
+        per_head_embed_dim = config.attention_head_dim  # note: this one is not a common attribute name
+        self_attention_cache_shape = (batch_size, config.num_key_value_heads, seq_length, per_head_embed_dim)
+        # build the full cache shapes, including mamba layers
+        all_cache_shapes = []
+        for i in range(config.num_hidden_layers):
+            if config.layers_block_type[i] == "mamba":
+                all_cache_shapes.append([torch.Size([batch_size, 0]), torch.Size([batch_size, 0])])
+            else:
+                all_cache_shapes.append([self_attention_cache_shape, self_attention_cache_shape])
+        super().test_past_key_values_format(custom_all_cache_shapes=all_cache_shapes)

     @unittest.skip(reason="Zamba2 has hybrid cache.")
     def test_generate_continue_from_inputs_embeds(self):

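For illustration, under an assumed tiny hybrid layout (not a real Zamba2 configuration), the `all_cache_shapes` list built by the override above mixes empty mamba entries with regular attention shapes:

import torch

# Assumed toy hybrid layout: only the comparison against "mamba" mirrors the override above;
# every other label falls through to the attention branch. All values are illustrative only.
batch_size, seq_length, num_key_value_heads, attention_head_dim = 2, 5, 4, 8
layers_block_type = ["mamba", "attention", "mamba"]

self_attention_cache_shape = (batch_size, num_key_value_heads, seq_length, attention_head_dim)
all_cache_shapes = []
for block_type in layers_block_type:
    if block_type == "mamba":
        # mamba layers keep an empty [batch_size, 0] placeholder in the KV cache
        all_cache_shapes.append([torch.Size([batch_size, 0]), torch.Size([batch_size, 0])])
    else:
        all_cache_shapes.append([self_attention_cache_shape, self_attention_cache_shape])

assert all_cache_shapes[0][0] == torch.Size([2, 0])
assert all_cache_shapes[1][0] == (2, 4, 5, 8)
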