Simplify the implementation of jitter noise in moe models (#27643)

This commit is contained in:
Wangyi Jiang 2023-11-22 18:49:40 +08:00 committed by GitHub
parent b54993aa94
commit c651eb23c3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 2 additions and 20 deletions

View File

@@ -188,17 +188,8 @@ class GPTSanJapaneseTop1Router(nn.Module):
  hidden_states = hidden_states.to(self.dtype)
  if self.jitter_noise > 0:
- # Get the lower and upper bound of the uniform distribution
- # Adapted from: https://stackoverflow.com/questions/44328530/how-to-get-a-uniform-distribution-in-a-range-r1-r2-in-pytorch
- distrib_lower_bound = 1.0 - self.jitter_noise
- distrib_upper_bound = 1.0 + self.jitter_noise
- uniform_distrib = torch.rand(hidden_states.shape, device=hidden_states.device, dtype=self.dtype)
- uniform_distrib = uniform_distrib * (distrib_lower_bound - distrib_upper_bound)
- uniform_distrib = uniform_distrib + distrib_upper_bound
  # Multiply the token inputs by the uniform distribution - adding some noise
- hidden_states *= uniform_distrib
+ hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
  # Shape: [num_groups, tokens_per_group, num_experts]
  self._cast_classifier()

View File

@@ -169,17 +169,8 @@ class SwitchTransformersTop1Router(nn.Module):
  hidden_states = hidden_states.to(self.dtype)
  if self.jitter_noise > 0:
- # Get the lower and upper bound of the uniform distribution
- # Adapted from: https://stackoverflow.com/questions/44328530/how-to-get-a-uniform-distribution-in-a-range-r1-r2-in-pytorch
- distrib_lower_bound = 1.0 - self.jitter_noise
- distrib_upper_bound = 1.0 + self.jitter_noise
- uniform_distrib = torch.rand(hidden_states.shape, device=hidden_states.device, dtype=self.dtype)
- uniform_distrib = uniform_distrib * (distrib_lower_bound - distrib_upper_bound)
- uniform_distrib = uniform_distrib + distrib_upper_bound
  # Multiply the token inputs by the uniform distribution - adding some noise
- hidden_states *= uniform_distrib
+ hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
  # Shape: [num_groups, tokens_per_group, num_experts]
  self._cast_classifier()