make Llama4TextMoe forward more readable (#37529)

* update forward of Llama4TextMoe

* remove redundant transpose

* fix formatting

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
JJJYmmm 2025-05-28 17:54:45 +08:00 committed by GitHub
parent defeb04299
commit 565a0052ed
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -138,36 +138,23 @@ class Llama4TextMoe(nn.Module):
self.shared_expert = Llama4TextMLP(config) self.shared_expert = Llama4TextMLP(config)
def forward(self, hidden_states):
    """Route tokens through the MoE layer and combine with the shared expert.

    Args:
        hidden_states: input activations; reshaped to ``(-1, hidden_dim)``,
            so presumably ``(batch, seq_len, hidden_dim)`` — TODO confirm.

    Returns:
        A tuple ``(out, router_scores)`` where ``out`` has shape
        ``(batch * seq_len, hidden_dim)`` and ``router_scores`` has shape
        ``(num_experts, batch * seq_len)``.
    """
    # Flatten all tokens into one (tokens, hidden_dim) matrix.
    hidden_states = hidden_states.reshape(-1, self.hidden_dim)
    router_logits = self.router(hidden_states)

    # Keep only each token's top-k logits; all other entries are set to -inf
    # so that the sigmoid below maps non-selected experts to a weight of 0.
    router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1)
    router_scores = (
        torch.full_like(router_logits, float("-inf"))
        .scatter_(1, router_indices, router_top_value)
        .transpose(0, 1)  # -> (num_experts, tokens)
    )
    # Sigmoid in float32 for numerical stability, then cast back.
    router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype)

    # Dense MoE: every token is fed to every expert, pre-scaled by its router
    # weight (tokens with weight 0 contribute nothing for that expert).
    routed_in = hidden_states.repeat(self.num_experts, 1)
    routed_in = routed_in * router_scores.reshape(-1, 1)
    routed_out = self.experts(routed_in)

    out = self.shared_expert(hidden_states)
    # Accumulate the per-expert outputs onto the shared-expert output.
    out.add_(routed_out.reshape(self.num_experts, -1, self.hidden_dim).sum(dim=0))
    return out, router_scores