Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Llama4: remove redundant transpose of router_logits (#37468)
* Llama4: remove redundant transpose of router_logits
* Fix formatting
This commit is contained in:
  parent: 6f7ea1cf00
  commit: ecaeee66bc
```diff
@@ -159,14 +159,12 @@ class Llama4TextMoe(nn.Module):
     def forward(self, hidden_states):
         batch, seq_len, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_dim)
-        router_logits = self.router(hidden_states).transpose(0, 1)
+        router_logits = self.router(hidden_states)
         tokens_per_expert = batch * seq_len

-        router_top_value, router_indices = torch.topk(router_logits.transpose(0, 1), self.top_k, dim=1)
+        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1)
         router_scores = (
-            torch.full_like(router_logits.transpose(0, 1), float("-inf"))
-            .scatter_(1, router_indices, router_top_value)
-            .transpose(0, 1)
+            torch.full_like(router_logits, float("-inf")).scatter_(1, router_indices, router_top_value).transpose(0, 1)
         )
         # We do this to make sure we have -inf for non topK tokens before going through the !
         # Here we are just creating a tensor to index each and every single one of the hidden states. Let s maybe register a buffer for this!
```
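The change works because the extra `.transpose(0, 1)` applied to the router output was immediately transposed back before `torch.topk` and `torch.full_like`, so the logits can simply stay in `(tokens, num_experts)` layout and the final `.transpose(0, 1)` alone produces the same `router_scores`. Below is a minimal standalone sketch of that equivalence; the dimensions are made up and a plain `nn.Linear` stands in for the actual Llama4 router, so the names here are illustrative rather than the model's real API.

```python
import torch
import torch.nn as nn

# Illustrative sizes only; the real model takes these from its config.
tokens, hidden_dim, num_experts, top_k = 8, 16, 4, 2

router = nn.Linear(hidden_dim, num_experts, bias=False)  # stand-in for the Llama4 router
hidden_states = torch.randn(tokens, hidden_dim)

# Old path: transpose the logits, then immediately transpose back before topk/full_like.
old_logits = router(hidden_states).transpose(0, 1)                 # (num_experts, tokens)
old_top, old_idx = torch.topk(old_logits.transpose(0, 1), top_k, dim=1)
old_scores = (
    torch.full_like(old_logits.transpose(0, 1), float("-inf"))
    .scatter_(1, old_idx, old_top)
    .transpose(0, 1)
)

# New path: keep the logits in (tokens, num_experts) layout and skip the round trip.
new_logits = router(hidden_states)                                  # (tokens, num_experts)
new_top, new_idx = torch.topk(new_logits, top_k, dim=1)
new_scores = (
    torch.full_like(new_logits, float("-inf")).scatter_(1, new_idx, new_top).transpose(0, 1)
)

# Identical result, with one fewer transpose pair in the forward pass.
assert torch.equal(old_scores, new_scores)
```

In both paths the scores end up transposed to `(num_experts, tokens)` at the very end; only the redundant intermediate transposes are dropped.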