diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py
index d2d9929b912..d7a158e9e8d 100644
--- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py
+++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py
@@ -301,10 +301,8 @@ class SwitchTransformersSparseMLP(nn.Module):
         router_mask, router_probs, router_logits = self.router(hidden_states)
         expert_index = torch.argmax(router_mask, dim=-1)
 
-        # The routers introduced might not always map all the tokens, to a router, which means that some hidden states
-        # can be unchanged from one layer to another. That is why the hidden states are cloned before updating only the selected ones.
-
-        next_states = hidden_states.clone()
+        # If a token gets dropped, we just set it to zero such that it does not get updated.
+        next_states = torch.zeros(hidden_states.shape, device=hidden_states.device, dtype=hidden_states.dtype)
 
         router_mask = router_mask.bool()
         batch_size, seq_len, num_experts = router_mask.shape
diff --git a/tests/models/switch_transformers/test_modeling_switch_transformers.py b/tests/models/switch_transformers/test_modeling_switch_transformers.py
index 0af2f32549b..18475d96034 100644
--- a/tests/models/switch_transformers/test_modeling_switch_transformers.py
+++ b/tests/models/switch_transformers/test_modeling_switch_transformers.py
@@ -42,6 +42,7 @@ if is_torch_available():
         SwitchTransformersEncoderModel,
         SwitchTransformersForConditionalGeneration,
         SwitchTransformersModel,
+        SwitchTransformersSparseMLP,
         SwitchTransformersTop1Router,
     )
     from transformers.models.switch_transformers.modeling_switch_transformers import (
@@ -1133,3 +1134,16 @@ class SwitchTransformerModelIntegrationTests(unittest.TestCase):
 
         for i in range(0, BATCH_SIZE, 2):
             self.assertEqual(batch_output[i], batch_output[i + 1])
+
+
+@require_torch
+class SwitchTransformersSparseMLPTests(unittest.TestCase):
+    def test_token_dropping(self):
+        r"""
+        This test checks if the token dropping actually drops tokens.
+        """
+        config = SwitchTransformersConfig(expert_capacity=0)  # we drop everything
+        moe = SwitchTransformersSparseMLP(config)
+        dropped_token_results = moe(torch.randn(2, 3, 768))[0]
+
+        assert (dropped_token_results == 0).all(), f"Some tokens not dropped: {dropped_token_results}."
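
For context, a minimal standalone sketch of the dispatch loop this patch touches. It is simplified for illustration: sparse_mlp_dispatch, the experts list, and the tensor shapes are hypothetical stand-ins, not the exact signatures in modeling_switch_transformers.py.

    import torch
    from torch import nn

    def sparse_mlp_dispatch(hidden_states, router_mask, router_probs, experts):
        # Start from zeros: tokens the router drops (e.g. because expert
        # capacity is exhausted) keep a zero output here, so they reach the
        # next layer only through the surrounding residual connection.
        # Cloning hidden_states instead, as before this patch, let the
        # dropped tokens' unprocessed inputs leak into the output, scaled
        # by router_probs.
        next_states = torch.zeros_like(hidden_states)
        for idx, expert in enumerate(experts):
            token_indices = router_mask[:, :, idx]  # bool mask for expert idx
            next_states[token_indices] = expert(hidden_states[token_indices])
        return router_probs * next_states

    # Toy check mirroring the new test: with an all-False mask (every token
    # dropped), the output is exactly zero.
    hidden = torch.randn(2, 3, 8)
    mask = torch.zeros(2, 3, 4, dtype=torch.bool)  # (batch, seq, num_experts)
    probs = torch.rand(2, 3, 1)
    experts = [nn.Linear(8, 8) for _ in range(4)]
    assert (sparse_mlp_dispatch(hidden, mask, probs, experts) == 0).all()

Zeroing dropped tokens also matches the behavior described in the Switch Transformer paper, where tokens that overflow an expert's capacity skip the expert computation entirely and are passed to the next layer through the residual connection alone.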