mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-30 17:52:35 +06:00
Simplify the implementation of jitter noise in moe models (#27643)
This commit is contained in:
parent
b54993aa94
commit
c651eb23c3
@ -188,17 +188,8 @@ class GPTSanJapaneseTop1Router(nn.Module):
|
||||
hidden_states = hidden_states.to(self.dtype)
|
||||
|
||||
if self.jitter_noise > 0:
|
||||
# Get the lower and upper bound of the uniform distribution
|
||||
# Adapted from: https://stackoverflow.com/questions/44328530/how-to-get-a-uniform-distribution-in-a-range-r1-r2-in-pytorch
|
||||
distrib_lower_bound = 1.0 - self.jitter_noise
|
||||
distrib_upper_bound = 1.0 + self.jitter_noise
|
||||
|
||||
uniform_distrib = torch.rand(hidden_states.shape, device=hidden_states.device, dtype=self.dtype)
|
||||
uniform_distrib = uniform_distrib * (distrib_lower_bound - distrib_upper_bound)
|
||||
|
||||
uniform_distrib = uniform_distrib + distrib_upper_bound
|
||||
# Multiply the token inputs by the uniform distribution - adding some noise
|
||||
hidden_states *= uniform_distrib
|
||||
hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
|
||||
|
||||
# Shape: [num_groups, tokens_per_group, num_experts]
|
||||
self._cast_classifier()
|
||||
|
@ -169,17 +169,8 @@ class SwitchTransformersTop1Router(nn.Module):
|
||||
hidden_states = hidden_states.to(self.dtype)
|
||||
|
||||
if self.jitter_noise > 0:
|
||||
# Get the lower and upper bound of the uniform distribution
|
||||
# Adapted from: https://stackoverflow.com/questions/44328530/how-to-get-a-uniform-distribution-in-a-range-r1-r2-in-pytorch
|
||||
distrib_lower_bound = 1.0 - self.jitter_noise
|
||||
distrib_upper_bound = 1.0 + self.jitter_noise
|
||||
|
||||
uniform_distrib = torch.rand(hidden_states.shape, device=hidden_states.device, dtype=self.dtype)
|
||||
uniform_distrib = uniform_distrib * (distrib_lower_bound - distrib_upper_bound)
|
||||
|
||||
uniform_distrib = uniform_distrib + distrib_upper_bound
|
||||
# Multiply the token inputs by the uniform distribution - adding some noise
|
||||
hidden_states *= uniform_distrib
|
||||
hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
|
||||
|
||||
# Shape: [num_groups, tokens_per_group, num_experts]
|
||||
self._cast_classifier()
|
||||
|
Loading…
Reference in New Issue
Block a user