Correctly drop tokens in SwitchTransformer (#37123)
Previously, dropped tokens passed through as the identity, and were then scaled by a weight from an expert that was never applied to their hidden states. This was misleading, because dropping a token means its expert contribution should be zero. Instead of trying to fix the weight, we take an easier approach: initialize `next_states` with zeros (see the sketch below). Fixes issue https://github.com/huggingface/transformers/issues/37017
This commit is contained in:
parent 7ecc5b88c0
commit bde41d69b4
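To make the change concrete, here is a minimal self-contained sketch (toy tensors and a stand-in expert, not the actual SwitchTransformersSparseMLP code) of why zero-initializing `next_states` gives dropped tokens a zero expert contribution:

import torch

# Toy shapes for illustration only.
batch, seq, d_model = 2, 3, 8
hidden_states = torch.randn(batch, seq, d_model)
router_probs = torch.rand(batch, seq, 1)  # per-token routing weight

# Pretend the router kept only the first position and dropped the rest
# (e.g. because expert capacity was exceeded).
kept = torch.zeros(batch, seq, dtype=torch.bool)
kept[:, 0] = True

# Old behavior: next_states = hidden_states.clone(), so dropped tokens kept
# their input and were then scaled by router_probs, a weight no expert applied.
# New behavior: start from zeros, so dropped tokens contribute nothing.
next_states = torch.zeros_like(hidden_states)
next_states[kept] = hidden_states[kept] * 2.0  # stand-in for a real expert MLP

expert_output = router_probs * next_states
assert (expert_output[~kept] == 0).all()  # dropped tokens: zero expert output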
@@ -301,10 +301,8 @@ class SwitchTransformersSparseMLP(nn.Module):
         router_mask, router_probs, router_logits = self.router(hidden_states)
         expert_index = torch.argmax(router_mask, dim=-1)

-        # The routers introduced might not always map all the tokens, to a router, which means that some hidden states
-        # can be unchanged from one layer to another. That is why the hidden states are cloned before updating only the selected ones.
-
-        next_states = hidden_states.clone()
+        # If a token gets dropped, we just set it to zero such that it does not get updated.
+        next_states = torch.zeros(hidden_states.shape, device=hidden_states.device, dtype=hidden_states.dtype)

         router_mask = router_mask.bool()
         batch_size, seq_len, num_experts = router_mask.shape
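Downstream of this change, the sparse MLP's output is scaled by router_probs and added back through the block's residual connection (as in the T5-style layer wrapping this MLP), so a zeroed `next_states` means dropped tokens pass through the layer unchanged. A small sketch of that interaction, with toy tensors and all tokens dropped:

import torch

hidden_states = torch.randn(2, 3, 8)
router_probs = torch.rand(2, 3, 1)
next_states = torch.zeros_like(hidden_states)  # every token dropped in this toy case

forwarded_states = router_probs * next_states  # zero for each dropped token
output = hidden_states + forwarded_states      # residual add in the enclosing block
assert torch.equal(output, hidden_states)      # dropped tokens are left untouched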
@@ -42,6 +42,7 @@ if is_torch_available():
         SwitchTransformersEncoderModel,
         SwitchTransformersForConditionalGeneration,
         SwitchTransformersModel,
+        SwitchTransformersSparseMLP,
         SwitchTransformersTop1Router,
     )
     from transformers.models.switch_transformers.modeling_switch_transformers import (
@@ -1133,3 +1134,16 @@ class SwitchTransformerModelIntegrationTests(unittest.TestCase):

         for i in range(0, BATCH_SIZE, 2):
             self.assertEqual(batch_output[i], batch_output[i + 1])
+
+
+@require_torch
+class SwitchTransformersSparseMLPTests(unittest.TestCase):
+    def test_token_dropping(self):
+        r"""
+        This test checks if the token dropping actually drops tokens.
+        """
+        config = SwitchTransformersConfig(expert_capacity=0)  # we drop everything
+        moe = SwitchTransformersSparseMLP(config)
+        dropped_token_results = moe(torch.randn(2, 3, 768))[0]
+
+        assert (dropped_token_results == 0).all(), f"Some tokens not dropped: {dropped_token_results}."
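For a quick manual check outside the test suite, the new test can be reproduced in a few lines (a sketch, assuming a transformers install that includes this fix; 768 is the config's default d_model):

import torch
from transformers import SwitchTransformersConfig
from transformers.models.switch_transformers.modeling_switch_transformers import SwitchTransformersSparseMLP

config = SwitchTransformersConfig(expert_capacity=0)  # capacity 0 drops every token
moe = SwitchTransformersSparseMLP(config)
output = moe(torch.randn(2, 3, 768))[0]  # forward returns (hidden_states, router tuple)
print((output == 0).all())  # tensor(True) with the zero-initialization fix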