use initializer_range

Kashif Rasul 2025-02-14 20:57:31 +01:00 committed by Jinan Zhou
parent d8c2e0d74f
commit a75b8e7a2d
2 changed files with 13 additions and 14 deletions

@@ -66,9 +66,8 @@ class TimesFmConfig(PretrainedConfig):
             The dropout probability for the attention scores.
         use_positional_embedding (`bool`, *optional*, defaults to `True`):
             Whether to add positional embeddings.
-        initializer_factor (`float`, *optional*, defaults to 1.0):
-            A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
-            testing).
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
     """

     model_type = "timesfm"
@@ -97,7 +96,7 @@ class TimesFmConfig(PretrainedConfig):
         pad_val: float = 1123581321.0,
         attention_dropout: float = 0.0,
         use_positional_embedding: bool = True,
-        initializer_factor: float = 1.0,
+        initializer_range: float = 0.02,
         **kwargs,
     ):
         self.patch_len = patch_len
@@ -115,7 +114,7 @@ class TimesFmConfig(PretrainedConfig):
         self.rms_norm_eps = rms_norm_eps
         self.attention_dropout = attention_dropout
         self.use_positional_embedding = use_positional_embedding
-        self.initializer_factor = initializer_factor
+        self.initializer_range = initializer_range
         super().__init__(
             is_encoder_decoder=self.is_encoder_decoder,

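For reference, a minimal usage sketch of the renamed config argument (not part of the diff; the top-level `transformers` import path for `TimesFmConfig` is assumed here):

# Hypothetical sketch: the renamed argument is passed to the config exactly like
# the old `initializer_factor` was, but it now defaults to 0.02.
from transformers import TimesFmConfig

config = TimesFmConfig(
    use_positional_embedding=True,
    initializer_range=0.02,  # was `initializer_factor=1.0` before this commit
)
print(config.initializer_range)  # -> 0.02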
@@ -617,10 +617,10 @@ class TimesFmPreTrainedModel(PreTrainedModel):
     def _init_weights(self, module):
         if isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0, std=self.config.initializer_factor)
+            module.weight.data.normal_(mean=0, std=self.config.initializer_range)
         elif isinstance(module, nn.Linear):
-            module.weight.data.normal_(mean=0, std=self.config.initializer_factor)
+            module.weight.data.normal_(mean=0, std=self.config.initializer_range)
             if module.bias is not None:
                 nn.init.zeros_(module.bias)
@@ -633,12 +633,12 @@ class TimesFmPreTrainedModel(PreTrainedModel):
         elif isinstance(module, TimesFmMLP):
             # Initialize gate projection
-            module.gate_proj.weight.data.normal_(mean=0, std=self.config.initializer_factor)
+            module.gate_proj.weight.data.normal_(mean=0, std=self.config.initializer_range)
             if module.gate_proj.bias is not None:
                 nn.init.zeros_(module.gate_proj.bias)
             # Initialize down projection
-            module.down_proj.weight.data.normal_(mean=0, std=self.config.initializer_factor)
+            module.down_proj.weight.data.normal_(mean=0, std=self.config.initializer_range)
             if module.down_proj.bias is not None:
                 nn.init.zeros_(module.down_proj.bias)
@@ -648,12 +648,12 @@ class TimesFmPreTrainedModel(PreTrainedModel):
         elif isinstance(module, TimesFmAttention):
             # Initialize qkv projection
-            module.qkv_proj.weight.data.normal_(mean=0, std=self.config.initializer_factor)
+            module.qkv_proj.weight.data.normal_(mean=0, std=self.config.initializer_range)
             if module.qkv_proj.bias is not None:
                 nn.init.zeros_(module.qkv_proj.bias)
             # Initialize output projection
-            module.o_proj.weight.data.normal_(mean=0, std=self.config.initializer_factor)
+            module.o_proj.weight.data.normal_(mean=0, std=self.config.initializer_range)
             if module.o_proj.bias is not None:
                 nn.init.zeros_(module.o_proj.bias)
@@ -662,17 +662,17 @@ class TimesFmPreTrainedModel(PreTrainedModel):
         elif isinstance(module, TimesFmResidualBlock):
             # Initialize hidden layer
-            module.hidden_layer[0].weight.data.normal_(mean=0, std=self.config.initializer_factor)
+            module.hidden_layer[0].weight.data.normal_(mean=0, std=self.config.initializer_range)
             if module.hidden_layer[0].bias is not None:
                 nn.init.zeros_(module.hidden_layer[0].bias)
             # Initialize output layer
-            module.output_layer.weight.data.normal_(mean=0, std=self.config.initializer_factor)
+            module.output_layer.weight.data.normal_(mean=0, std=self.config.initializer_range)
             if module.output_layer.bias is not None:
                 nn.init.zeros_(module.output_layer.bias)
             # Initialize residual layer
-            module.residual_layer.weight.data.normal_(mean=0, std=self.config.initializer_factor)
+            module.residual_layer.weight.data.normal_(mean=0, std=self.config.initializer_range)
             if module.residual_layer.bias is not None:
                 nn.init.zeros_(module.residual_layer.bias)
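
As a standalone illustration of the pattern `_init_weights` applies after this change (a sketch, not code from the diff): each weight matrix is drawn from a normal distribution whose standard deviation is `config.initializer_range`, and biases are zeroed. The layer sizes below are arbitrary.

import torch.nn as nn

initializer_range = 0.02  # would come from TimesFmConfig.initializer_range

layer = nn.Linear(16, 32)
# Same call the diff switches to: std is now the config's initializer_range.
layer.weight.data.normal_(mean=0.0, std=initializer_range)
if layer.bias is not None:
    nn.init.zeros_(layer.bias)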