Mirror of https://github.com/huggingface/transformers.git
Disable jitter noise during evaluation in SwitchTransformers (#28077)
* Disable jitter noise during evaluation
* Update outdated configuration information
* Formatting
* Add new line
parent a0522de497
commit 7c5408dade
@@ -187,7 +187,7 @@ class GPTSanJapaneseTop1Router(nn.Module):
         self.input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(self.dtype)
 
-        if self.jitter_noise > 0:
+        if self.training and self.jitter_noise > 0:
             # Multiply the token inputs by the uniform distribution - adding some noise
             hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
 
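For context, here is a minimal sketch (plain PyTorch, not the repository's actual class; the module name and sizes are hypothetical) of the pattern this change applies: multiplicative jitter noise on the router inputs is applied only while the module is in training mode, so `model.eval()` makes routing deterministic.

import torch
from torch import nn


class JitteredRouterSketch(nn.Module):
    # Hypothetical module illustrating training-only jitter noise; not part of transformers.
    def __init__(self, hidden_size: int, num_experts: int, jitter_noise: float = 0.01):
        super().__init__()
        self.classifier = nn.Linear(hidden_size, num_experts, bias=False)
        self.jitter_noise = jitter_noise

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Apply noise only in training mode, mirroring the guarded condition in the diff above.
        if self.training and self.jitter_noise > 0:
            hidden_states = hidden_states * torch.empty_like(hidden_states).uniform_(
                1.0 - self.jitter_noise, 1.0 + self.jitter_noise
            )
        return self.classifier(hidden_states).softmax(dim=-1)


router = JitteredRouterSketch(hidden_size=8, num_experts=4)
x = torch.randn(2, 5, 8)

router.eval()   # self.training is False, so no noise: repeated calls agree exactly
assert torch.equal(router(x), router(x))

router.train()  # noise is re-enabled; consecutive calls may now differ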
@@ -38,7 +38,7 @@ class SwitchTransformersConfig(PretrainedConfig):
         vocab_size (`int`, *optional*, defaults to 32128):
             Vocabulary size of the SwitchTransformers model. Defines the number of different tokens that can be
             represented by the `inputs_ids` passed when calling [`SwitchTransformersModel`].
-        d_model (`int`, *optional*, defaults to 512):
+        d_model (`int`, *optional*, defaults to 768):
             Size of the encoder layers and the pooler layer.
         d_kv (`int`, *optional*, defaults to 64):
             Size of the key, query, value projections per attention head. `d_kv` has to be equal to `d_model //
@@ -50,21 +50,19 @@ class SwitchTransformersConfig(PretrainedConfig):
             Transformer.
         num_layers (`int`, *optional*, defaults to 12):
             Number of dense hidden layers in the Transformer encoder layer.
-        num_sparse_encoder_layers (`int`, *optional*, defaults to 6):
+        num_sparse_encoder_layers (`int`, *optional*, defaults to 3):
             Number of sparse (MoE) dense hidden layers in the Transformer encoder layer.
         num_decoder_layers (`int`, *optional*, defaults to 12):
             Number of hidden layers in the Transformer decoder. Will use the same value as `num_layers` if not set.
-        num_sparse_decoder_layers (`int`, *optional*, defaults to 12):
+        num_sparse_decoder_layers (`int`, *optional*, defaults to 3):
             Number of sparse (MoE) dense hidden layers in the Transformer decoder layer.
-        num_heads (`int`, *optional*, defaults to 8):
+        num_heads (`int`, *optional*, defaults to 12):
             Number of attention heads for each attention layer in the Transformer encoder.
         num_experts (`int`, *optional*, defaults to 8):
             Number of experts for each SwitchTransformer layer.
-        router_type (`str`, *optional*, defaults to `"tokens_masked"`):
-            Router type - choose between `"tokens_masked", `"tokens_scatter"` and `"experts_masked"`.
-        router_bias (`bool`, *optional*, defaults to `True`):
+        router_bias (`bool`, *optional*, defaults to `False`):
             Whether to add a bias to the router.
-        router_jitter_noise (`float`, *optional*, defaults to 0.1):
+        router_jitter_noise (`float`, *optional*, defaults to 0.01):
             Amount of noise to add to the router.
         router_dtype (`str`, *optional*, default to `"float32"`):
             The `dtype` used for the routers. It is preferable to keep the `dtype` to `"float32"` as specified in the
@@ -83,10 +81,10 @@ class SwitchTransformersConfig(PretrainedConfig):
             The z loss factor for the total loss.
         router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
             The aux loss factor for the total loss.
-        initializer_factor (`float`, *optional*, defaults to 1):
+        initializer_factor (`float`, *optional*, defaults to 1.0):
             A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
             testing).
-        feed_forward_proj (`string`, *optional*, defaults to `"relu"`):
+        dense_act_fn (`string`, *optional*, defaults to `"relu"`):
             Type of feed forward layer to be used. Should be one of `"relu"` or `"gated-gelu"`. SwitchTransformersv1.1
             uses the `"gated-gelu"` feed forward projection. Original SwitchTransformers uses `"relu"`.
         add_router_probs (`bool`, *optional*, defaults to `False`):
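The docstring edits above only bring the documented defaults in line with what `SwitchTransformersConfig` already uses. A quick, hedged way to confirm this locally (assuming a `transformers` install that includes this commit) is to instantiate the config with no arguments and read the attributes named in the diff:

from transformers import SwitchTransformersConfig

config = SwitchTransformersConfig()  # default values only, no pretrained weights involved

# Values the updated docstring now advertises.
print(config.d_model)                    # expected 768
print(config.num_sparse_encoder_layers)  # expected 3
print(config.num_sparse_decoder_layers)  # expected 3
print(config.num_heads)                  # expected 12
print(config.router_bias)                # expected False
print(config.router_jitter_noise)        # expected 0.01
print(config.initializer_factor)         # expected 1.0
print(config.dense_act_fn)               # expected "relu"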
@@ -168,7 +168,7 @@ class SwitchTransformersTop1Router(nn.Module):
         self.input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(self.dtype)
 
-        if self.jitter_noise > 0:
+        if self.training and self.jitter_noise > 0:
             # Multiply the token inputs by the uniform distribution - adding some noise
             hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
 
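With the `self.training` guard in place, two forward passes through the same router in evaluation mode should produce identical routing decisions, while training mode can still differ because of the jitter. Below is a minimal sketch of such a check, assuming `SwitchTransformersTop1Router` can be instantiated directly from a config and called on a `[batch, seq_len, d_model]` tensor as in the method edited above:

import torch

from transformers import SwitchTransformersConfig
from transformers.models.switch_transformers.modeling_switch_transformers import (
    SwitchTransformersTop1Router,
)

# Small, randomly initialized config; no pretrained weights are needed for this check.
config = SwitchTransformersConfig(d_model=64, num_experts=4)
router = SwitchTransformersTop1Router(config)

hidden_states = torch.randn(2, 10, config.d_model)

router.eval()   # training is False, so the jitter branch above is skipped
eval_a = router(hidden_states)
eval_b = router(hidden_states)
assert all(torch.equal(a, b) for a, b in zip(eval_a, eval_b))  # deterministic routing

router.train()  # jitter noise (router_jitter_noise, 0.01 by default) is applied again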