remove output_router_logits

ryan u 2025-02-18 19:48:25 +09:00
parent 67f1f0ca15
commit f264f800d0
2 changed files with 6 additions and 24 deletions

View File

@@ -157,7 +157,7 @@ class DeepseekV3TopkRouter(nn.Module):
         denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
         topk_weights /= denominator
         topk_weights = topk_weights * self.routed_scaling_factor
-        return topk_indices, topk_weights, router_logits
+        return topk_indices, topk_weights

     @torch.no_grad()
     def get_topk_indices(self, scores):
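For orientation, here is a minimal sketch of the contract change on the router side: the gate still computes router logits internally but now returns only the top-k indices and weights. This is a toy stand-in, not the actual transformers module; the hidden size, expert count, sigmoid scoring, and scaling factor are illustrative assumptions.

import torch
import torch.nn as nn

class ToyTopkRouter(nn.Module):
    # Toy stand-in for DeepseekV3TopkRouter; all hyperparameters are made up.
    def __init__(self, hidden_size=16, n_experts=8, top_k=2, routed_scaling_factor=1.0):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(n_experts, hidden_size))
        self.top_k = top_k
        self.routed_scaling_factor = routed_scaling_factor

    def forward(self, hidden_states):
        # Logits are still computed internally; they just stop being returned.
        router_logits = nn.functional.linear(hidden_states.view(-1, hidden_states.shape[-1]), self.weight)
        scores = router_logits.sigmoid()
        topk_weights, topk_indices = scores.topk(self.top_k, dim=-1)
        denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
        topk_weights = topk_weights / denominator * self.routed_scaling_factor
        return topk_indices, topk_weights  # two values now, not three

router = ToyTopkRouter()
topk_indices, topk_weights = router(torch.randn(4, 16))  # callers unpack two values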
@@ -202,11 +202,11 @@ class DeepseekV3MoE(nn.Module):
     def forward(self, hidden_states):
         residuals = hidden_states
         orig_shape = hidden_states.shape
-        topk_indices, topk_weights, router_logits = self.gate(hidden_states)
+        topk_indices, topk_weights = self.gate(hidden_states)
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
         hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape)
         hidden_states = hidden_states + self.shared_experts(residuals)
-        return hidden_states, router_logits
+        return hidden_states

     def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor):
         final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype)
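With DeepseekV3MoE.forward returning a plain tensor, the only consumers of the router output are the experts themselves. Below is a hedged, self-contained sketch of that combine step; the experts are plain Linear layers and the shapes are illustrative, so the real moe method may differ in detail.

import torch
import torch.nn as nn

def moe_combine(hidden_states, topk_indices, topk_weights, experts):
    # hidden_states: (tokens, hidden); topk_indices/topk_weights: (tokens, top_k).
    final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype)
    # One routing mask per expert: (n_experts, tokens, top_k).
    expert_mask = nn.functional.one_hot(topk_indices, num_classes=len(experts)).permute(2, 0, 1)
    for i, expert in enumerate(experts):
        token_idx, k_idx = torch.where(expert_mask[i])
        if token_idx.numel() == 0:
            continue  # this expert received no tokens
        expert_out = expert(hidden_states[token_idx]) * topk_weights[token_idx, k_idx, None]
        final_hidden_states.index_add_(0, token_idx, expert_out.to(final_hidden_states.dtype))
    return final_hidden_states

experts = nn.ModuleList(nn.Linear(16, 16) for _ in range(8))
hidden = torch.randn(4, 16)
weights, indices = torch.rand(4, 8).topk(2, dim=-1)  # fake router output
combined = moe_combine(hidden, indices, weights, experts)  # (4, 16)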
@@ -444,7 +444,6 @@ class DeepseekV3DecoderLayer(nn.Module):
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
         output_attentions: Optional[bool] = False,
-        output_router_logits: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
@@ -472,19 +471,11 @@ class DeepseekV3DecoderLayer(nn.Module):
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = self.mlp(hidden_states)
-        if isinstance(hidden_states, tuple):
-            hidden_states, router_logits = hidden_states
-        else:
-            router_logits = (torch.zeros((1,), device=hidden_states.device, dtype=torch.int64),)
         hidden_states = residual + hidden_states

         outputs = (hidden_states,)
         if output_attentions:
             outputs += (self_attn_weights,)
-        if output_router_logits:
-            outputs += (router_logits,)

         return outputs
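On the decoder-layer side the net effect is a slimmer output tuple: (hidden_states,) plus optionally the attention weights, with no router-logits entry to branch on. A minimal runnable sketch of that return path, with LayerNorm/Linear stand-ins for the real attention and MoE sublayers:

import torch
import torch.nn as nn

class ToyDecoderLayer(nn.Module):
    def __init__(self, hidden_size=16):
        super().__init__()
        self.post_attention_layernorm = nn.LayerNorm(hidden_size)
        self.mlp = nn.Linear(hidden_size, hidden_size)  # stand-in for the MoE block

    def forward(self, hidden_states, output_attentions=False):
        self_attn_weights = None  # a real layer would get this from self-attention
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)  # always a plain tensor now
        hidden_states = residual + hidden_states
        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)
        return outputs  # no isinstance check, no output_router_logits branch

layer = ToyDecoderLayer()
(hidden,) = layer(torch.randn(2, 4, 16))  # a 1-tuple by default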

View File

@@ -82,7 +82,7 @@ class DeepseekV3TopkRouter(nn.Module):
         denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
         topk_weights /= denominator
         topk_weights = topk_weights * self.routed_scaling_factor
-        return topk_indices, topk_weights, router_logits
+        return topk_indices, topk_weights

     @torch.no_grad()
     def get_topk_indices(self, scores):
@@ -127,11 +127,11 @@ class DeepseekV3MoE(nn.Module):
     def forward(self, hidden_states):
         residuals = hidden_states
         orig_shape = hidden_states.shape
-        topk_indices, topk_weights, router_logits = self.gate(hidden_states)
+        topk_indices, topk_weights = self.gate(hidden_states)
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
         hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape)
         hidden_states = hidden_states + self.shared_experts(residuals)
-        return hidden_states, router_logits
+        return hidden_states

     def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor):
         final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype)
@@ -291,7 +291,6 @@ class DeepseekV3DecoderLayer(nn.Module):
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
         output_attentions: Optional[bool] = False,
-        output_router_logits: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
@@ -319,19 +318,11 @@ class DeepseekV3DecoderLayer(nn.Module):
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = self.mlp(hidden_states)
-        if isinstance(hidden_states, tuple):
-            hidden_states, router_logits = hidden_states
-        else:
-            router_logits = (torch.zeros((1,), device=hidden_states.device, dtype=torch.int64),)
         hidden_states = residual + hidden_states

         outputs = (hidden_states,)
         if output_attentions:
             outputs += (self_attn_weights,)
-        if output_router_logits:
-            outputs += (router_logits,)

         return outputs