def forward(self, hidden_states):
    """Mixture-of-experts forward pass: all experts run on all tokens.

    Tokens are flattened to ``(num_tokens, hidden_dim)``. Each token keeps
    only its top-k router logits; the rest are set to ``-inf`` so the sigmoid
    gate below maps them to exactly zero — non-selected experts therefore
    contribute nothing even though every expert processes every token
    (dense dispatch; no gather/scatter indexing needed).

    Returns:
        ``(out, router_scores)`` where ``out`` is ``(num_tokens, hidden_dim)``
        — shared-expert output plus the gated sum over experts — and
        ``router_scores`` is ``(num_experts, num_tokens)`` post-sigmoid gates.
    """
    tokens = hidden_states.reshape(-1, self.hidden_dim)

    logits = self.router(tokens)
    top_values, top_indices = torch.topk(logits, self.top_k, dim=1)

    # Keep only each token's top-k logits; everything else becomes -inf so
    # that sigmoid turns it into an exact zero gate.
    masked = torch.full_like(logits, float("-inf"))
    masked.scatter_(1, top_indices, top_values)
    router_scores = masked.transpose(0, 1)
    router_scores = torch.sigmoid(router_scores.float()).to(tokens.dtype)

    # One copy of the token stream per expert, each copy pre-scaled by its
    # gate value, then pushed through the fused expert computation.
    expert_inputs = tokens.repeat(self.num_experts, 1) * router_scores.reshape(-1, 1)
    expert_outputs = self.experts(expert_inputs)

    # Shared expert runs on the raw tokens; the per-expert contributions are
    # folded back in with an in-place sum over the expert dimension.
    out = self.shared_expert(tokens)
    out.add_(expert_outputs.reshape(self.num_experts, -1, self.hidden_dim).sum(dim=0))

    return out, router_scores