mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 10:12:23 +06:00
Removed unnecessary transpose in Switch Transformer Routing (#33582)
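Removed the unnecessary `transpose(1, 2)` applied to `router_mask` before it is reshaped to `(batch_size * seq_len, num_experts)` in Switch Transformer routing; `router_mask` already has shape `(batch_size, seq_len, num_experts)`.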
This commit is contained in:
parent
78ef58325c
commit
614660fdb9
@ -298,7 +298,7 @@ class SwitchTransformersSparseMLP(nn.Module):
|
||||
|
||||
router_mask = router_mask.bool()
|
||||
batch_size, seq_len, num_experts = router_mask.shape
|
||||
idx_mask = router_mask.transpose(1, 2).reshape(batch_size * seq_len, num_experts).sum(dim=0)
|
||||
idx_mask = router_mask.reshape(batch_size * seq_len, num_experts).sum(dim=0)
|
||||
idx_mask = torch.nonzero(idx_mask, as_tuple=True)[
|
||||
0
|
||||
].tolist() # length: number of "activated" expert / value: index
|
||||
|
Loading…
Reference in New Issue
Block a user