Removed unnecessary transpose in Switch Transformer Routing (#33582)

removed switch transformer routing transpose
This commit is contained in:
karan-uppal3 2024-10-04 21:09:03 +05:30 committed by GitHub
parent 78ef58325c
commit 614660fdb9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -298,7 +298,7 @@ class SwitchTransformersSparseMLP(nn.Module):
router_mask = router_mask.bool()
batch_size, seq_len, num_experts = router_mask.shape
-idx_mask = router_mask.transpose(1, 2).reshape(batch_size * seq_len, num_experts).sum(dim=0)
+idx_mask = router_mask.reshape(batch_size * seq_len, num_experts).sum(dim=0)
idx_mask = torch.nonzero(idx_mask, as_tuple=True)[
0
].tolist() # length: number of "activated" expert / value: index