Correctly drop tokens in SwitchTransformer (#37123)

Previously, dropped tokens were passed through as an identity, yet still scaled by a router weight from an expert that was never applied to the hidden states.
This was misleading, because dropping a token means its expert weight is effectively zero.
Instead of trying to fix the weight, we take the simpler approach of initializing next_states with zeros.

Fixes issue https://github.com/huggingface/transformers/issues/37017
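To make the change concrete, here is a minimal, self-contained sketch (toy shapes, not the actual SwitchTransformersSparseMLP code) contrasting the two initializations when the router drops every token; the downstream scaling by router_probs mirrors what the model does after the expert loop:

import torch

hidden_states = torch.randn(2, 3, 4)
router_probs = torch.rand(2, 3, 1)  # router confidence, applied downstream

# Old behavior: next_states starts as a clone, so a dropped token passes
# through unchanged and is then scaled by the weight of an expert that
# never processed it.
old_out = router_probs * hidden_states.clone()

# New behavior: next_states starts at zero, so a dropped token contributes
# nothing, matching an expert weight of zero.
new_out = router_probs * torch.zeros_like(hidden_states)

assert not torch.all(old_out == 0)  # old: spurious scaled identity
assert torch.all(new_out == 0)      # new: truly dropped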
Mario Michael Krell 2025-04-10 07:58:57 -07:00 committed by GitHub
parent 7ecc5b88c0
commit bde41d69b4
2 changed files with 16 additions and 4 deletions

src/transformers/models/switch_transformers/modeling_switch_transformers.py

@@ -301,10 +301,8 @@ class SwitchTransformersSparseMLP(nn.Module):
         router_mask, router_probs, router_logits = self.router(hidden_states)
         expert_index = torch.argmax(router_mask, dim=-1)
-        # The routers introduced might not always map all the tokens, to a router, which means that some hidden states
-        # can be unchanged from one layer to another. That is why the hidden states are cloned before updating only the selected ones.
-        next_states = hidden_states.clone()
+        # If a token gets dropped, we just set it to zero such that it does not get updated.
+        next_states = torch.zeros(hidden_states.shape, device=hidden_states.device, dtype=hidden_states.dtype)
         router_mask = router_mask.bool()
         batch_size, seq_len, num_experts = router_mask.shape

tests/models/switch_transformers/test_modeling_switch_transformers.py

@@ -42,6 +42,7 @@ if is_torch_available():
         SwitchTransformersEncoderModel,
         SwitchTransformersForConditionalGeneration,
         SwitchTransformersModel,
+        SwitchTransformersSparseMLP,
         SwitchTransformersTop1Router,
     )
     from transformers.models.switch_transformers.modeling_switch_transformers import (
@@ -1133,3 +1134,16 @@ class SwitchTransformerModelIntegrationTests(unittest.TestCase):
         for i in range(0, BATCH_SIZE, 2):
             self.assertEqual(batch_output[i], batch_output[i + 1])
+
+
+@require_torch
+class SwitchTransformersSparseMLPTests(unittest.TestCase):
+    def test_token_dropping(self):
+        r"""
+        This test checks if the token dropping actually drops tokens.
+        """
+        config = SwitchTransformersConfig(expert_capacity=0)  # we drop everything
+        moe = SwitchTransformersSparseMLP(config)
+        dropped_token_results = moe(torch.randn(2, 3, 768))[0]
+        assert (dropped_token_results == 0).all(), f"Some tokens not dropped: {dropped_token_results}."
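Assuming the new test lives in tests/models/switch_transformers/test_modeling_switch_transformers.py, it can be run locally with pytest -k test_token_dropping on that file; with expert_capacity=0 every token is dropped, so the sparse MLP's output must be all zeros.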