mirror of https://github.com/huggingface/transformers.git
synced 2025-07-03 21:00:08 +06:00
remove output_router_logits
This commit is contained in:
parent 67f1f0ca15
commit f264f800d0
@@ -157,7 +157,7 @@ class DeepseekV3TopkRouter(nn.Module):
         denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
         topk_weights /= denominator
         topk_weights = topk_weights * self.routed_scaling_factor
-        return topk_indices, topk_weights, router_logits
+        return topk_indices, topk_weights
 
     @torch.no_grad()
     def get_topk_indices(self, scores):
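For context, here is a minimal, self-contained sketch of the routing step this hunk touches. The function below is an assumption for illustration (the real DeepseekV3TopkRouter computes scores with a learned gate and grouped top-k selection); it only mirrors the normalization shown above and the new two-value return. `topk_route` and the toy scores are not code from the commit.

import torch

def topk_route(scores: torch.Tensor, top_k: int, routed_scaling_factor: float):
    # Illustrative top-k routing mirroring the hunk above: select experts,
    # renormalize their weights, then apply the routed scaling factor.
    topk_weights, topk_indices = scores.topk(top_k, dim=-1)
    denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
    topk_weights = topk_weights / denominator
    topk_weights = topk_weights * routed_scaling_factor
    # After this commit the router returns two values; router_logits are
    # no longer part of the return signature.
    return topk_indices, topk_weights

scores = torch.rand(4, 8)  # (tokens, experts) -- toy values
topk_indices, topk_weights = topk_route(scores, top_k=2, routed_scaling_factor=2.5)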
@@ -202,11 +202,11 @@ class DeepseekV3MoE(nn.Module):
     def forward(self, hidden_states):
         residuals = hidden_states
         orig_shape = hidden_states.shape
-        topk_indices, topk_weights, router_logits = self.gate(hidden_states)
+        topk_indices, topk_weights = self.gate(hidden_states)
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
         hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape)
         hidden_states = hidden_states + self.shared_experts(residuals)
-        return hidden_states, router_logits
+        return hidden_states
 
     def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor):
         final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype)
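To see the changed call contract end to end, here is a runnable toy sketch. This is not the real DeepseekV3MoE: ToyMoE, its linear experts, the softmax gate, and all shapes are invented for illustration. It only follows the post-commit flow above, where the gate returns (topk_indices, topk_weights) and forward returns a single tensor.

import torch
import torch.nn as nn

class ToyMoE(nn.Module):
    # Toy stand-in for the DeepseekV3MoE flow after this commit: the gate
    # returns two values and forward returns a single tensor.
    def __init__(self, hidden=16, n_experts=4, top_k=2):
        super().__init__()
        self.gate_proj = nn.Linear(hidden, n_experts, bias=False)  # toy gate, not the real router
        self.experts = nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(n_experts))
        self.shared_experts = nn.Linear(hidden, hidden)  # stand-in for the shared-expert MLP
        self.top_k = top_k

    def gate(self, hidden_states):
        scores = self.gate_proj(hidden_states).softmax(dim=-1)
        topk_weights, topk_indices = scores.topk(self.top_k, dim=-1)
        return topk_indices, topk_weights  # two values, no router_logits

    def forward(self, hidden_states):
        residuals = hidden_states
        orig_shape = hidden_states.shape
        topk_indices, topk_weights = self.gate(hidden_states)
        flat = hidden_states.view(-1, hidden_states.shape[-1])
        idx = topk_indices.view(-1, self.top_k)
        w = topk_weights.view(-1, self.top_k)
        out = torch.zeros_like(flat)
        # Naive per-token dispatch; the real moe() gathers tokens per expert.
        for token in range(flat.shape[0]):
            for slot in range(self.top_k):
                out[token] += w[token, slot] * self.experts[int(idx[token, slot])](flat[token])
        hidden_states = out.view(*orig_shape) + self.shared_experts(residuals)
        return hidden_states  # single tensor, no router_logits

x = torch.randn(2, 3, 16)
y = ToyMoE()(x)  # y.shape == (2, 3, 16)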
@@ -444,7 +444,6 @@ class DeepseekV3DecoderLayer(nn.Module):
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
         output_attentions: Optional[bool] = False,
-        output_router_logits: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
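One practical consequence of this signature hunk, sketched with a hypothetical stub (the stub and its names are illustrative, not the real layer): since output_router_logits is removed from the parameter list, passing it as a keyword argument now raises a TypeError.

def forward_stub(hidden_states, output_attentions=False, use_cache=False):
    # Trimmed signature: output_router_logits is gone.
    return (hidden_states,)

forward_stub("x", output_attentions=True)         # fine
try:
    forward_stub("x", output_router_logits=True)  # no longer accepted
except TypeError as e:
    print(e)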
@@ -472,19 +471,11 @@ class DeepseekV3DecoderLayer(nn.Module):
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = self.mlp(hidden_states)
-
-        if isinstance(hidden_states, tuple):
-            hidden_states, router_logits = hidden_states
-        else:
-            router_logits = (torch.zeros((1,), device=hidden_states.device, dtype=torch.int64),)
-
         hidden_states = residual + hidden_states
 
         outputs = (hidden_states,)
         if output_attentions:
             outputs += (self_attn_weights,)
-        if output_router_logits:
-            outputs += (router_logits,)
 
         return outputs
 
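The decoder-layer hunk above deletes both the tuple-unpacking of the MLP output and the output_router_logits branch. The surviving output assembly reduces to the small helper below; the standalone function form and its name are mine, while the body is adapted from the kept lines of the hunk.

import torch
from typing import Optional, Tuple

def assemble_layer_outputs(
    hidden_states: torch.Tensor,
    self_attn_weights: Optional[torch.Tensor] = None,
    output_attentions: bool = False,
) -> Tuple[torch.Tensor, ...]:
    # Post-commit behavior: only hidden states, plus attention weights on
    # request; no router_logits entry and no placeholder torch.zeros logits.
    outputs = (hidden_states,)
    if output_attentions:
        outputs += (self_attn_weights,)
    return outputs

outs = assemble_layer_outputs(torch.randn(2, 3, 16), output_attentions=False)
hidden = outs[0]

The same four changes are applied again in the hunks below (note the smaller @@ line numbers), which cover the second file touched by this commit.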
@@ -82,7 +82,7 @@ class DeepseekV3TopkRouter(nn.Module):
         denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
         topk_weights /= denominator
         topk_weights = topk_weights * self.routed_scaling_factor
-        return topk_indices, topk_weights, router_logits
+        return topk_indices, topk_weights
 
     @torch.no_grad()
     def get_topk_indices(self, scores):
@@ -127,11 +127,11 @@ class DeepseekV3MoE(nn.Module):
     def forward(self, hidden_states):
         residuals = hidden_states
         orig_shape = hidden_states.shape
-        topk_indices, topk_weights, router_logits = self.gate(hidden_states)
+        topk_indices, topk_weights = self.gate(hidden_states)
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
         hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape)
         hidden_states = hidden_states + self.shared_experts(residuals)
-        return hidden_states, router_logits
+        return hidden_states
 
     def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor):
         final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype)
@@ -291,7 +291,6 @@ class DeepseekV3DecoderLayer(nn.Module):
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
         output_attentions: Optional[bool] = False,
-        output_router_logits: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
@@ -319,19 +318,11 @@ class DeepseekV3DecoderLayer(nn.Module):
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = self.mlp(hidden_states)
-
-        if isinstance(hidden_states, tuple):
-            hidden_states, router_logits = hidden_states
-        else:
-            router_logits = (torch.zeros((1,), device=hidden_states.device, dtype=torch.int64),)
-
         hidden_states = residual + hidden_states
 
         outputs = (hidden_states,)
         if output_attentions:
             outputs += (self_attn_weights,)
-        if output_router_logits:
-            outputs += (router_logits,)
 
         return outputs
 