Llama4: remove redundant transpose of router_logits (#37468)

* Llama4: remove redundant transpose of router_logits

* Fix formatting
Pavel Belevich 2025-04-15 07:29:26 -04:00 committed by GitHub
parent 6f7ea1cf00
commit ecaeee66bc

@@ -159,14 +159,12 @@ class Llama4TextMoe(nn.Module):
     def forward(self, hidden_states):
         batch, seq_len, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_dim)
-        router_logits = self.router(hidden_states).transpose(0, 1)
+        router_logits = self.router(hidden_states)
         tokens_per_expert = batch * seq_len
 
-        router_top_value, router_indices = torch.topk(router_logits.transpose(0, 1), self.top_k, dim=1)
+        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1)
         router_scores = (
-            torch.full_like(router_logits.transpose(0, 1), float("-inf"))
-            .scatter_(1, router_indices, router_top_value)
-            .transpose(0, 1)
+            torch.full_like(router_logits, float("-inf")).scatter_(1, router_indices, router_top_value).transpose(0, 1)
         )
         # We do this to make sure we have -inf for non topK tokens before going through the !
         # Here we are just creating a tensor to index each and every single one of the hidden states. Let s maybe register a buffer for this!
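For context, a minimal standalone sketch (not part of the patch, shapes and names are illustrative) of why the transposes were redundant: the old code flipped router_logits to (num_experts, tokens) right after the router and then transposed back at every use, so dropping both transposes leaves router_scores unchanged.

import torch

# Illustrative sizes only; in the model these come from the config.
tokens, num_experts, top_k = 8, 4, 2
router_logits = torch.randn(tokens, num_experts)

# Old path: transpose right after the router, then transpose back at every use.
old_logits = router_logits.transpose(0, 1)
old_top_value, old_indices = torch.topk(old_logits.transpose(0, 1), top_k, dim=1)
old_scores = (
    torch.full_like(old_logits.transpose(0, 1), float("-inf"))
    .scatter_(1, old_indices, old_top_value)
    .transpose(0, 1)
)

# New path: keep router_logits as (tokens, num_experts) throughout.
new_top_value, new_indices = torch.topk(router_logits, top_k, dim=1)
new_scores = (
    torch.full_like(router_logits, float("-inf"))
    .scatter_(1, new_indices, new_top_value)
    .transpose(0, 1)
)

assert torch.equal(old_scores, new_scores)  # identical result, two transposes fewer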