Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-03 04:40:06 +06:00)
GraniteMoeHybrid: Allow for only shared expert case. (#38801)

* Allow for only shared expert case.
* Style

This commit is contained in:
parent a7593a1d1f
commit 5ab0f447ab
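Not part of the commit, but for context: a minimal usage sketch of what this change enables. With num_local_experts=0, every decoder layer skips the routed MoE block and relies only on the shared expert MLP. The tiny sizes below are arbitrary, and the snippet assumes a transformers build that already contains this commit.

from transformers import GraniteMoeHybridConfig, GraniteMoeHybridForCausalLM

# Shared-expert-only configuration: no routed experts are created at all.
config = GraniteMoeHybridConfig(
    num_local_experts=0,          # block_sparse_moe is never instantiated
    shared_intermediate_size=64,  # the shared expert MLP is still built and used
    hidden_size=64,
    num_hidden_layers=2,
)
model = GraniteMoeHybridForCausalLM(config)

# Every decoder layer now reports has_experts == False, so its feed-forward
# step uses shared_mlp only and returns router_logits = None.
print(all(not layer.has_experts for layer in model.model.layers))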
@@ -504,7 +504,8 @@ class GraniteMoeDecoderLayer(GradientCheckpointingLayer):
         self.hidden_size = config.hidden_size

         self.self_attn = GraniteMoeAttention(config=config, layer_idx=layer_idx)
-        self.block_sparse_moe = GraniteMoeMoE(config)
+        if config.num_local_experts > 0:
+            self.block_sparse_moe = GraniteMoeMoE(config)
         self.input_layernorm = GraniteMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = GraniteMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

@@ -1051,7 +1051,8 @@ class GraniteMoeHybridDecoderLayer(GradientCheckpointingLayer):
         self.hidden_size = config.hidden_size
         # Either attention or mamba will be initialized, depending on the layer type.
         self.self_attn = None
-        self.block_sparse_moe = GraniteMoeHybridMoE(config)
+        if config.num_local_experts > 0:
+            self.block_sparse_moe = GraniteMoeHybridMoE(config)
         self.input_layernorm = GraniteMoeHybridRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = GraniteMoeHybridRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

@@ -1065,6 +1066,9 @@ class GraniteMoeHybridDecoderLayer(GradientCheckpointingLayer):
             self.self_attn = GraniteMoeHybridAttention(config, layer_idx)
         self.layer_type = config.layers_block_type[layer_idx]

+        # Accept 0 experts: skip MoE if num_local_experts == 0
+        self.has_experts = getattr(config, "num_local_experts", 0) > 0
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -1131,9 +1135,14 @@ class GraniteMoeHybridDecoderLayer(GradientCheckpointingLayer):
         # Fully Connected
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
-        moe_hidden_states, router_logits = self.block_sparse_moe(hidden_states)
-
-        hidden_states = moe_hidden_states + self.shared_mlp(hidden_states)
+        if self.has_experts:
+            moe_hidden_states, router_logits = self.block_sparse_moe(hidden_states)
+            hidden_states = moe_hidden_states + self.shared_mlp(hidden_states)
+        else:
+            hidden_states = self.shared_mlp(hidden_states)
+            router_logits = None
+
         hidden_states = residual + hidden_states * self.residual_multiplier

         outputs = (hidden_states,)
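For readers skimming the diff, the forward-path change above boils down to the following self-contained pattern. This is an illustrative paraphrase, not the transformers code: TinyGuardedFFN and its Linear stand-ins are made up, and the real block_sparse_moe/shared_mlp are MoE and MLP modules rather than single linear layers.

import torch
from torch import nn

class TinyGuardedFFN(nn.Module):
    # Sketch of the guarded feed-forward branch: routed experts are optional,
    # the shared expert is always present.
    def __init__(self, hidden_size: int, num_local_experts: int):
        super().__init__()
        self.has_experts = num_local_experts > 0
        if self.has_experts:
            self.block_sparse_moe = nn.Linear(hidden_size, hidden_size)  # stand-in for the routed MoE
        self.shared_mlp = nn.Linear(hidden_size, hidden_size)  # stand-in for the shared expert MLP

    def forward(self, hidden_states: torch.Tensor):
        if self.has_experts:
            moe_hidden_states = self.block_sparse_moe(hidden_states)
            router_logits = torch.zeros(hidden_states.shape[:-1])  # placeholder for real router logits
            hidden_states = moe_hidden_states + self.shared_mlp(hidden_states)
        else:
            # Shared-expert-only path: no routed experts, no router logits.
            hidden_states = self.shared_mlp(hidden_states)
            router_logits = None
        return hidden_states, router_logits

ffn = TinyGuardedFFN(hidden_size=8, num_local_experts=0)
out, logits = ffn(torch.randn(2, 4, 8))
print(out.shape, logits)  # torch.Size([2, 4, 8]) None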
@@ -71,6 +71,9 @@ class GraniteMoeHybridDecoderLayer(GraniteMoeSharedDecoderLayer):
             self.self_attn = GraniteMoeHybridAttention(config, layer_idx)
         self.layer_type = config.layers_block_type[layer_idx]

+        # Accept 0 experts: skip MoE if num_local_experts == 0
+        self.has_experts = getattr(config, "num_local_experts", 0) > 0
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -137,9 +140,14 @@ class GraniteMoeHybridDecoderLayer(GraniteMoeSharedDecoderLayer):
         # Fully Connected
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
-        moe_hidden_states, router_logits = self.block_sparse_moe(hidden_states)
-
-        hidden_states = moe_hidden_states + self.shared_mlp(hidden_states)
+        if self.has_experts:
+            moe_hidden_states, router_logits = self.block_sparse_moe(hidden_states)
+            hidden_states = moe_hidden_states + self.shared_mlp(hidden_states)
+        else:
+            hidden_states = self.shared_mlp(hidden_states)
+            router_logits = None
+
         hidden_states = residual + hidden_states * self.residual_multiplier

         outputs = (hidden_states,)
@@ -412,7 +412,8 @@ class GraniteMoeSharedDecoderLayer(GradientCheckpointingLayer):
         self.hidden_size = config.hidden_size

         self.self_attn = GraniteMoeSharedAttention(config=config, layer_idx=layer_idx)
-        self.block_sparse_moe = GraniteMoeSharedMoE(config)
+        if config.num_local_experts > 0:
+            self.block_sparse_moe = GraniteMoeSharedMoE(config)
         self.input_layernorm = GraniteMoeSharedRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = GraniteMoeSharedRMSNorm(config.hidden_size, eps=config.rms_norm_eps)