mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 13:20:12 +06:00
make Llama4TextMoe forward more readable (#37529)
* update forward of Llama4TextMoe * remove redundant transpose * fix formatting --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
parent
defeb04299
commit
565a0052ed
@ -138,36 +138,23 @@ class Llama4TextMoe(nn.Module):
        self.shared_expert = Llama4TextMLP(config)

    def forward(self, hidden_states):
        batch, seq_len, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.reshape(-1, self.hidden_dim)
        router_logits = self.router(hidden_states)
        tokens_per_expert = batch * seq_len

        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1)

        router_scores = (
            torch.full_like(router_logits, float("-inf")).scatter_(1, router_indices, router_top_value).transpose(0, 1)
        )
        # We do this to make sure we have -inf for non topK tokens before going through the !
        # Here we are just creating a tensor to index each and every single one of the hidden states. Let's maybe register a buffer for this!
        router_indices = (
            torch.arange(tokens_per_expert, device=hidden_states.device).view(1, -1).expand(router_scores.size(0), -1)
        )
        router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype)

        router_indices = router_indices.reshape(-1, 1).expand(-1, hidden_dim)
        routed_in = torch.gather(
            input=hidden_states,
            dim=0,
            index=router_indices,
        ).to(hidden_states.device)
        # we gather inputs corresponding to each expert based on the router indices
        routed_in = hidden_states.repeat(self.num_experts, 1)
        routed_in = routed_in * router_scores.reshape(-1, 1)
        routed_out = self.experts(routed_in)

        out = self.shared_expert(hidden_states)
        # now that we finished expert computation -> we scatter add because we gathered previously
        # we have to do this because we used all experts on all tokens. This is faster than the for loop, tho you are compute bound
        # this scales a lot better if you do EP!
        out.scatter_add_(dim=0, index=router_indices, src=routed_out.view(-1, hidden_dim))
        out.add_(routed_out.reshape(self.num_experts, -1, self.hidden_dim).sum(dim=0))

        return out, router_scores
Loading…
Reference in New Issue
Block a user