GraniteMoeHybrid: Allow for only shared expert case. (#38801)

* Allow for only shared expert case.

* Style
Shawn Tan 2025-06-16 11:15:42 -04:00 committed by GitHub
parent a7593a1d1f
commit 5ab0f447ab
4 changed files with 26 additions and 7 deletions
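
For orientation, here is a minimal usage sketch of the case this commit enables: a GraniteMoeHybrid model built with num_local_experts=0, so each decoder layer keeps only the shared expert. This is a hedged sketch, not part of the commit; it assumes GraniteMoeHybridConfig and GraniteMoeHybridForCausalLM are importable from transformers and accept the keyword arguments shown, and the sizes and layers_block_type value are illustrative.

# Hedged sketch: exercise the num_local_experts == 0 path end to end.
# Assumes the classes and kwargs below exist as named; sizes are illustrative.
import torch
from transformers import GraniteMoeHybridConfig, GraniteMoeHybridForCausalLM

config = GraniteMoeHybridConfig(
    vocab_size=128,
    hidden_size=64,
    intermediate_size=128,
    shared_intermediate_size=128,        # width of the shared expert MLP
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
    num_local_experts=0,                 # no routed experts: shared expert only
    layers_block_type=["attention", "attention"],
)
model = GraniteMoeHybridForCausalLM(config)

input_ids = torch.randint(0, config.vocab_size, (1, 8))
with torch.no_grad():
    logits = model(input_ids).logits

# With zero local experts, decoder layers are built without block_sparse_moe
# and the forward pass uses only the shared MLP (router_logits stays None).
print(logits.shape)
print(hasattr(model.model.layers[0], "block_sparse_moe"))  # expected: False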


@@ -504,7 +504,8 @@ class GraniteMoeDecoderLayer(GradientCheckpointingLayer):
         self.hidden_size = config.hidden_size
         self.self_attn = GraniteMoeAttention(config=config, layer_idx=layer_idx)
-        self.block_sparse_moe = GraniteMoeMoE(config)
+        if config.num_local_experts > 0:
+            self.block_sparse_moe = GraniteMoeMoE(config)
         self.input_layernorm = GraniteMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = GraniteMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)


@@ -1051,7 +1051,8 @@ class GraniteMoeHybridDecoderLayer(GradientCheckpointingLayer):
         self.hidden_size = config.hidden_size
         # Either attention or mamba will be initialized, depending on the layer type.
         self.self_attn = None
-        self.block_sparse_moe = GraniteMoeHybridMoE(config)
+        if config.num_local_experts > 0:
+            self.block_sparse_moe = GraniteMoeHybridMoE(config)
         self.input_layernorm = GraniteMoeHybridRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = GraniteMoeHybridRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -1065,6 +1066,9 @@ class GraniteMoeHybridDecoderLayer(GradientCheckpointingLayer):
             self.self_attn = GraniteMoeHybridAttention(config, layer_idx)
         self.layer_type = config.layers_block_type[layer_idx]
 
+        # Accept 0 experts: skip MoE if num_local_experts == 0
+        self.has_experts = getattr(config, "num_local_experts", 0) > 0
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -1131,9 +1135,14 @@ class GraniteMoeHybridDecoderLayer(GradientCheckpointingLayer):
         # Fully Connected
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
-        moe_hidden_states, router_logits = self.block_sparse_moe(hidden_states)
-        hidden_states = moe_hidden_states + self.shared_mlp(hidden_states)
+        if self.has_experts:
+            moe_hidden_states, router_logits = self.block_sparse_moe(hidden_states)
+            hidden_states = moe_hidden_states + self.shared_mlp(hidden_states)
+        else:
+            hidden_states = self.shared_mlp(hidden_states)
+            router_logits = None
+
         hidden_states = residual + hidden_states * self.residual_multiplier
 
         outputs = (hidden_states,)
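
Read in isolation, the forward-path change above boils down to: when experts exist, combine block_sparse_moe with shared_mlp; otherwise run shared_mlp alone and report router_logits as None. A self-contained toy sketch of that branch (plain PyTorch, hypothetical ToyMoEBlock with linear stand-ins, not the actual GraniteMoeHybrid modules):

# Toy illustration only: stand-ins for block_sparse_moe / shared_mlp, not the real modules.
import torch
import torch.nn as nn

class ToyMoEBlock(nn.Module):
    def __init__(self, hidden_size: int, num_local_experts: int):
        super().__init__()
        self.has_experts = num_local_experts > 0
        if self.has_experts:
            self.block_sparse_moe = nn.Linear(hidden_size, hidden_size)  # stand-in for routed experts
        self.shared_mlp = nn.Linear(hidden_size, hidden_size)            # stand-in for the shared expert

    def forward(self, hidden_states):
        if self.has_experts:
            moe_hidden_states = self.block_sparse_moe(hidden_states)
            router_logits = torch.zeros(hidden_states.shape[0], 1)       # placeholder logits
            hidden_states = moe_hidden_states + self.shared_mlp(hidden_states)
        else:
            hidden_states = self.shared_mlp(hidden_states)
            router_logits = None
        return hidden_states, router_logits

out, logits = ToyMoEBlock(hidden_size=8, num_local_experts=0)(torch.randn(2, 8))
print(out.shape, logits)  # torch.Size([2, 8]) None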


@@ -71,6 +71,9 @@ class GraniteMoeHybridDecoderLayer(GraniteMoeSharedDecoderLayer):
             self.self_attn = GraniteMoeHybridAttention(config, layer_idx)
         self.layer_type = config.layers_block_type[layer_idx]
 
+        # Accept 0 experts: skip MoE if num_local_experts == 0
+        self.has_experts = getattr(config, "num_local_experts", 0) > 0
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -137,9 +140,14 @@ class GraniteMoeHybridDecoderLayer(GraniteMoeSharedDecoderLayer):
         # Fully Connected
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
-        moe_hidden_states, router_logits = self.block_sparse_moe(hidden_states)
-        hidden_states = moe_hidden_states + self.shared_mlp(hidden_states)
+        if self.has_experts:
+            moe_hidden_states, router_logits = self.block_sparse_moe(hidden_states)
+            hidden_states = moe_hidden_states + self.shared_mlp(hidden_states)
+        else:
+            hidden_states = self.shared_mlp(hidden_states)
+            router_logits = None
+
         hidden_states = residual + hidden_states * self.residual_multiplier
 
         outputs = (hidden_states,)


@@ -412,7 +412,8 @@ class GraniteMoeSharedDecoderLayer(GradientCheckpointingLayer):
         self.hidden_size = config.hidden_size
         self.self_attn = GraniteMoeSharedAttention(config=config, layer_idx=layer_idx)
-        self.block_sparse_moe = GraniteMoeSharedMoE(config)
+        if config.num_local_experts > 0:
+            self.block_sparse_moe = GraniteMoeSharedMoE(config)
         self.input_layernorm = GraniteMoeSharedRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = GraniteMoeSharedRMSNorm(config.hidden_size, eps=config.rms_norm_eps)