diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py
index e55e188bca7..37519fe435a 100644
--- a/src/transformers/models/granitemoe/modeling_granitemoe.py
+++ b/src/transformers/models/granitemoe/modeling_granitemoe.py
@@ -504,7 +504,8 @@ class GraniteMoeDecoderLayer(GradientCheckpointingLayer):
         self.hidden_size = config.hidden_size
 
         self.self_attn = GraniteMoeAttention(config=config, layer_idx=layer_idx)
-        self.block_sparse_moe = GraniteMoeMoE(config)
+        if config.num_local_experts > 0:
+            self.block_sparse_moe = GraniteMoeMoE(config)
         self.input_layernorm = GraniteMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = GraniteMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
diff --git a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py
index 032081bc04d..b1480085601 100644
--- a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py
+++ b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py
@@ -1051,7 +1051,8 @@ class GraniteMoeHybridDecoderLayer(GradientCheckpointingLayer):
         self.hidden_size = config.hidden_size
         # Either attention or mamba will be initialized, depending on the layer type.
         self.self_attn = None
-        self.block_sparse_moe = GraniteMoeHybridMoE(config)
+        if config.num_local_experts > 0:
+            self.block_sparse_moe = GraniteMoeHybridMoE(config)
         self.input_layernorm = GraniteMoeHybridRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = GraniteMoeHybridRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
@@ -1065,6 +1066,9 @@ class GraniteMoeHybridDecoderLayer(GradientCheckpointingLayer):
             self.self_attn = GraniteMoeHybridAttention(config, layer_idx)
         self.layer_type = config.layers_block_type[layer_idx]
 
+        # Accept 0 experts: skip MoE if num_local_experts == 0
+        self.has_experts = getattr(config, "num_local_experts", 0) > 0
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -1131,9 +1135,14 @@ class GraniteMoeHybridDecoderLayer(GradientCheckpointingLayer):
 
         # Fully Connected
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
-        moe_hidden_states, router_logits = self.block_sparse_moe(hidden_states)
-        hidden_states = moe_hidden_states + self.shared_mlp(hidden_states)
+        if self.has_experts:
+            moe_hidden_states, router_logits = self.block_sparse_moe(hidden_states)
+            hidden_states = moe_hidden_states + self.shared_mlp(hidden_states)
+        else:
+            hidden_states = self.shared_mlp(hidden_states)
+            router_logits = None
+
         hidden_states = residual + hidden_states * self.residual_multiplier
 
         outputs = (hidden_states,)
diff --git a/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py
index f9dc0ec3ea6..160e0aa1bf3 100644
--- a/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py
+++ b/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py
@@ -71,6 +71,9 @@ class GraniteMoeHybridDecoderLayer(GraniteMoeSharedDecoderLayer):
             self.self_attn = GraniteMoeHybridAttention(config, layer_idx)
         self.layer_type = config.layers_block_type[layer_idx]
 
+        # Accept 0 experts: skip MoE if num_local_experts == 0
+        self.has_experts = getattr(config, "num_local_experts", 0) > 0
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -137,9 +140,14 @@ class GraniteMoeHybridDecoderLayer(GraniteMoeSharedDecoderLayer):
 
         # Fully Connected
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
-        moe_hidden_states, router_logits = self.block_sparse_moe(hidden_states)
-        hidden_states = moe_hidden_states + self.shared_mlp(hidden_states)
+        if self.has_experts:
+            moe_hidden_states, router_logits = self.block_sparse_moe(hidden_states)
+            hidden_states = moe_hidden_states + self.shared_mlp(hidden_states)
+        else:
+            hidden_states = self.shared_mlp(hidden_states)
+            router_logits = None
+
         hidden_states = residual + hidden_states * self.residual_multiplier
 
         outputs = (hidden_states,)
diff --git a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py
index 139535b009d..da7ade3cf48 100644
--- a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py
+++ b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py
@@ -412,7 +412,8 @@ class GraniteMoeSharedDecoderLayer(GradientCheckpointingLayer):
         self.hidden_size = config.hidden_size
 
         self.self_attn = GraniteMoeSharedAttention(config=config, layer_idx=layer_idx)
-        self.block_sparse_moe = GraniteMoeSharedMoE(config)
+        if config.num_local_experts > 0:
+            self.block_sparse_moe = GraniteMoeSharedMoE(config)
         self.input_layernorm = GraniteMoeSharedRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = GraniteMoeSharedRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
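
Note (not part of the patch): a minimal sketch of the intended effect at construction time, assuming the GraniteMoeSharedConfig / GraniteMoeSharedForCausalLM classes from transformers and hypothetical toy hyperparameters chosen only to keep the model cheap to build. With num_local_experts=0 the decoder layer never creates block_sparse_moe, so the feed-forward path reduces to the shared MLP.

# Sketch under the assumptions above; toy sizes are illustrative, not real checkpoints.
from transformers import GraniteMoeSharedConfig, GraniteMoeSharedForCausalLM

config = GraniteMoeSharedConfig(
    vocab_size=128,
    hidden_size=64,
    intermediate_size=128,
    shared_intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
    num_local_experts=0,  # request a dense (expert-free) model
)
model = GraniteMoeSharedForCausalLM(config)

# With 0 experts the decoder layer skips the sparse MoE module entirely.
layer = model.model.layers[0]
print(hasattr(layer, "block_sparse_moe"))  # expected: False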