diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py
index 96b6c7334b1..f1495ddc8c0 100644
--- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py
+++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py
@@ -298,7 +298,7 @@ class SwitchTransformersSparseMLP(nn.Module):
         router_mask = router_mask.bool()
         batch_size, seq_len, num_experts = router_mask.shape
-        idx_mask = router_mask.transpose(1, 2).reshape(batch_size * seq_len, num_experts).sum(dim=0)
+        idx_mask = router_mask.reshape(batch_size * seq_len, num_experts).sum(dim=0)
         idx_mask = torch.nonzero(idx_mask, as_tuple=True)[
             0
         ].tolist()  # length: number of "activated" expert / value: index