From bde41d69b47c37e0dc1704cb4cd1a2a4709a4136 Mon Sep 17 00:00:00 2001
From: Mario Michael Krell <172859788+mario-aws@users.noreply.github.com>
Date: Thu, 10 Apr 2025 07:58:57 -0700
Subject: [PATCH] Correctly drop tokens in SwitchTransformer (#37123)

Previously, dropped tokens were passed through the identity function and
then scaled by the router weight of an expert that was never applied to
them. This was misleading, because dropping a token means its expert
weight is effectively zero. Instead of trying to fix the weight, we take
an easier approach and initialize the output with zeros.

Fixes issue https://github.com/huggingface/transformers/issues/37017

---
 .../modeling_switch_transformers.py      |  6 ++----
 .../test_modeling_switch_transformers.py | 14 ++++++++++++++
 2 files changed, 16 insertions(+), 4 deletions(-)

diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py
index d2d9929b912..d7a158e9e8d 100644
--- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py
+++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py
@@ -301,10 +301,8 @@ class SwitchTransformersSparseMLP(nn.Module):
         router_mask, router_probs, router_logits = self.router(hidden_states)
         expert_index = torch.argmax(router_mask, dim=-1)
 
-        # The routers introduced might not always map all the tokens, to a router, which means that some hidden states
-        # can be unchanged from one layer to another. That is why the hidden states are cloned before updating only the selected ones.
-
-        next_states = hidden_states.clone()
+        # If a token gets dropped, we just set it to zero such that it does not get updated.
+        next_states = torch.zeros(hidden_states.shape, device=hidden_states.device, dtype=hidden_states.dtype)
 
         router_mask = router_mask.bool()
         batch_size, seq_len, num_experts = router_mask.shape
diff --git a/tests/models/switch_transformers/test_modeling_switch_transformers.py b/tests/models/switch_transformers/test_modeling_switch_transformers.py
index 0af2f32549b..18475d96034 100644
--- a/tests/models/switch_transformers/test_modeling_switch_transformers.py
+++ b/tests/models/switch_transformers/test_modeling_switch_transformers.py
@@ -42,6 +42,7 @@ if is_torch_available():
         SwitchTransformersEncoderModel,
         SwitchTransformersForConditionalGeneration,
         SwitchTransformersModel,
+        SwitchTransformersSparseMLP,
         SwitchTransformersTop1Router,
     )
     from transformers.models.switch_transformers.modeling_switch_transformers import (
@@ -1133,3 +1134,16 @@ class SwitchTransformerModelIntegrationTests(unittest.TestCase):
 
         for i in range(0, BATCH_SIZE, 2):
             self.assertEqual(batch_output[i], batch_output[i + 1])
+
+
+@require_torch
+class SwitchTransformersSparseMLPTests(unittest.TestCase):
+    def test_token_dropping(self):
+        r"""
+        This test checks if the token dropping actually drops tokens.
+        """
+        config = SwitchTransformersConfig(expert_capacity=0)  # we drop everything
+        moe = SwitchTransformersSparseMLP(config)
+        dropped_token_results = moe(torch.randn(2, 3, 768))[0]
+
+        assert (dropped_token_results == 0).all(), f"Some tokens not dropped: {dropped_token_results}."
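
Note (not part of the patch): the sketch below illustrates the intended behaviour of the fix. It is a minimal standalone dispatch loop, not the actual SwitchTransformersSparseMLP.forward; the helper name dispatch_to_experts, the toy shapes, and the per-expert Python loop are assumptions made for readability.

import torch
from torch import nn

def dispatch_to_experts(hidden_states, router_mask, router_probs, experts):
    # router_mask is one-hot over experts; an all-zero row means the token was
    # dropped (e.g. the selected expert was already at capacity).
    #
    # Patched behaviour: start from zeros, so dropped tokens contribute nothing.
    # The old code started from hidden_states.clone(), which returned the
    # unprocessed hidden state scaled by a router weight that was never applied.
    next_states = torch.zeros(
        hidden_states.shape, device=hidden_states.device, dtype=hidden_states.dtype
    )
    router_mask = router_mask.bool()
    for idx, expert in enumerate(experts):
        token_mask = router_mask[..., idx]  # (batch, seq_len) tokens routed to this expert
        if token_mask.any():
            next_states[token_mask] = expert(hidden_states[token_mask]).to(next_states.dtype)
    return router_probs * next_states  # dropped tokens stay exactly zero

# Usage: with zero capacity nothing is routed, so the output is all zeros,
# mirroring the expert_capacity=0 case exercised by the new test.
batch, seq_len, d_model, num_experts = 2, 3, 8, 4
hidden = torch.randn(batch, seq_len, d_model)
mask = torch.zeros(batch, seq_len, num_experts)  # no token fits -> everything dropped
probs = torch.zeros(batch, seq_len, 1)
experts = [nn.Linear(d_model, d_model) for _ in range(num_experts)]
assert (dispatch_to_experts(hidden, mask, probs, experts) == 0).all()

Because the surrounding feed-forward block adds the sparse MLP output back through a residual connection, a zero output leaves a dropped token's hidden state unchanged, which is exactly the intended "drop" semantics.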