Fix Llama4 (#38222)

Update modeling_llama4.py
This commit is contained in:
Cyril Vallez 2025-05-20 16:00:46 +02:00 committed by GitHub
parent 3f0b7d0fac
commit b591d925be
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -144,7 +144,7 @@ class Llama4TextMoe(nn.Module):
def forward(self, hidden_states):
batch, seq_len, hidden_dim = hidden_states.shape
-hidden_states = hidden_states.view(-1, self.hidden_dim)
+hidden_states = hidden_states.reshape(-1, self.hidden_dim)
router_logits = self.router(hidden_states)
tokens_per_expert = batch * seq_len