Skip non-selected experts for mixtral and qwen2_moe (#32429)

* Skip non-selected experts for mixtral and qwen2_moe

* Fix: tensor tolist()

* WIP: tokenization test

* fix modular source of truth

* nits

---------

Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Kerry 2025-04-08 23:41:28 +08:00 committed by GitHub
parent 35f0f5b5da
commit aab0878327
3 changed files with 6 additions and 7 deletions

@@ -135,11 +135,10 @@ class MixtralSparseMoeBlock(nn.Module):
         # this will be used to easily index which expert is going to be sollicitated
         expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
-
-        # Loop over all available experts in the model and perform the computation on each expert
-        for expert_idx in range(self.num_experts):
+        expert_hitted = (expert_mask.sum(dim=(-1, -2)) > 0).nonzero(as_tuple=True)[0].tolist()
+        for expert_idx in expert_hitted:
             expert_layer = self.experts[expert_idx]
             idx, top_x = torch.where(expert_mask[expert_idx])
 
             # Index the correct hidden states and compute the expert hidden state for
             # the current expert. We need to make sure to multiply the output hidden
             # states by `routing_weights` on the corresponding tokens (top-1 and top-2)

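The change is the same in all three blocks: instead of iterating over every expert and doing an empty gather for the ones the router never picked, the block first computes `expert_hitted`, the indices of the experts that received at least one token, and loops only over those. A minimal sketch of how `expert_hitted` falls out of `expert_mask`, using toy routing values that are not from the commit:

    import torch

    # Toy routing: 4 tokens, top-2 over 8 experts. Shapes follow the
    # Mixtral block above; the values are made up for illustration.
    num_experts = 8
    selected_experts = torch.tensor([[0, 3], [3, 5], [0, 5], [3, 0]])  # (num_tokens, top_k)

    # (num_experts, top_k, num_tokens) after the permute, as in the diff
    expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=num_experts).permute(2, 1, 0)

    # An expert is "hit" iff its slice of the mask sums to a positive
    # count over the (top_k, num_tokens) dimensions.
    expert_hitted = (expert_mask.sum(dim=(-1, -2)) > 0).nonzero(as_tuple=True)[0].tolist()
    print(expert_hitted)  # [0, 3, 5] -- experts 1, 2, 4, 6 and 7 are skipped entirely

For a single decoded token with Mixtral's top-2 routing, only two of the eight expert MLPs run at all; the other six loop iterations disappear instead of dispatching ops on empty tensors.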

@@ -209,11 +209,10 @@ class MixtralSparseMoeBlock(nn.Module):
         # this will be used to easily index which expert is going to be sollicitated
         expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
-
-        # Loop over all available experts in the model and perform the computation on each expert
-        for expert_idx in range(self.num_experts):
+        expert_hitted = (expert_mask.sum(dim=(-1, -2)) > 0).nonzero(as_tuple=True)[0].tolist()
+        for expert_idx in expert_hitted:
             expert_layer = self.experts[expert_idx]
             idx, top_x = torch.where(expert_mask[expert_idx])
 
             # Index the correct hidden states and compute the expert hidden state for
             # the current expert. We need to make sure to multiply the output hidden
             # states by `routing_weights` on the corresponding tokens (top-1 and top-2)

@@ -616,7 +616,8 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
 
         # Loop over all available experts in the model and perform the computation on each expert
-        for expert_idx in range(self.num_experts):
+        expert_hitted = (expert_mask.sum(dim=(-1, -2)) > 0).nonzero(as_tuple=True)[0].tolist()
+        for expert_idx in expert_hitted:
             expert_layer = self.experts[expert_idx]
             idx, top_x = torch.where(expert_mask[expert_idx])
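
All three hunks cut off right after the gather indices are computed. To show where `expert_hitted` plugs in end to end, here is a self-contained sketch of the whole dispatch loop in the spirit of the Mixtral block, with stand-in `nn.Linear` experts and made-up routing weights in place of the real gated expert MLPs and softmax outputs:

    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    num_experts, top_k, hidden_dim = 8, 2, 16
    hidden_states = torch.randn(4, hidden_dim)     # (num_tokens, hidden_dim)
    routing_weights = torch.full((4, top_k), 0.5)  # stand-in normalized top-k weights
    selected_experts = torch.tensor([[0, 3], [3, 5], [0, 5], [3, 0]])
    experts = nn.ModuleList(nn.Linear(hidden_dim, hidden_dim) for _ in range(num_experts))

    expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=num_experts).permute(2, 1, 0)
    expert_hitted = (expert_mask.sum(dim=(-1, -2)) > 0).nonzero(as_tuple=True)[0].tolist()

    final_hidden_states = torch.zeros_like(hidden_states)
    for expert_idx in expert_hitted:               # skip experts with no routed tokens
        expert_layer = experts[expert_idx]
        idx, top_x = torch.where(expert_mask[expert_idx])  # idx: top-k slot, top_x: token row
        current_state = hidden_states[top_x]               # gather only the routed tokens
        weighted = expert_layer(current_state) * routing_weights[top_x, idx, None]
        final_hidden_states.index_add_(0, top_x, weighted)  # scatter-add back to token rows

The output matches the old `range(self.num_experts)` loop exactly: an expert with no routed tokens gets empty indices from `torch.where`, so its gather and `index_add_` were already no-ops. Skipping those iterations only removes wasted dispatch work.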