These MoEs do not support static cache (atm)

Ivar Flakstad 2025-06-29 20:50:41 +02:00
parent c8064bea9a
commit a9703d0c4d
3 changed files with 11 additions and 2 deletions
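
For context: `_supports_static_cache` is the capability flag that `GenerationMixin.generate()` consults before building a `StaticCache`, so flipping it to `False` makes these models reject `cache_implementation="static"` while the default dynamic cache keeps working. A minimal sketch of the downstream effect (the checkpoint id is a hypothetical stand-in for illustration, not taken from this commit):

# Sketch only; the model id below is an illustrative assumption.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "deepseek-ai/DeepSeek-V3"  # hypothetical example checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
inputs = tokenizer("Hello", return_tensors="pt")

# Default dynamic cache: unaffected by this commit.
model.generate(**inputs, max_new_tokens=8)

# Static cache: now refused, because _supports_static_cache is False.
try:
    model.generate(**inputs, max_new_tokens=8, cache_implementation="static")
except ValueError as err:
    print(err)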

src/transformers/models/deepseek_v3/modeling_deepseek_v3.py

@@ -510,7 +510,7 @@ class DeepseekV3PreTrainedModel(PreTrainedModel):
     _supports_flex_attn = True
     _supports_cache_class = True
     _supports_quantized_cache = True
-    _supports_static_cache = True
+    _supports_static_cache = False
     _supports_attention_backend = True
 
     def _init_weights(self, module):
@@ -531,6 +531,8 @@ class DeepseekV3PreTrainedModel(PreTrainedModel):
 
 @auto_docstring
 class DeepseekV3Model(DeepseekV3PreTrainedModel):
+    _supports_static_cache = False
+
     _keys_to_ignore_on_load_unexpected = [r"model\.layers\.61.*"]
 
     def __init__(self, config: DeepseekV3Config):
@@ -664,6 +666,7 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
     _tp_plan = {"lm_head": "colwise_rep"}
     _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+    _supports_static_cache = False
 
     def __init__(self, config):
         super().__init__(config)

src/transformers/models/deepseek_v3/modular_deepseek_v3.py

@@ -338,6 +338,8 @@ class DeepseekV3DecoderLayer(LlamaDecoderLayer, nn.Module):
 
 
 class DeepseekV3PreTrainedModel(LlamaPreTrainedModel):
+    _supports_static_cache = False
+
     def _init_weights(self, module):
         std = self.config.initializer_range
         if isinstance(module, nn.Linear):
@@ -355,10 +357,14 @@ class DeepseekV3PreTrainedModel(LlamaPreTrainedModel):
 
 
 class DeepseekV3Model(LlamaModel):
+    _supports_static_cache = False
+
     _keys_to_ignore_on_load_unexpected = [r"model\.layers\.61.*"]
 
 
 class DeepseekV3ForCausalLM(LlamaForCausalLM):
+    _supports_static_cache = False
+
     pass

src/transformers/models/dots1/modeling_dots1.py

@@ -430,7 +430,7 @@ class Dots1PreTrainedModel(PreTrainedModel):
     _supports_flex_attn = True
     _supports_cache_class = True
     _supports_quantized_cache = True
-    _supports_static_cache = True
+    _supports_static_cache = False
     _supports_attention_backend = True
 
     def _init_weights(self, module):