From 8deeb3e191b3671bc1d74dbfe77b736a066c3d34 Mon Sep 17 00:00:00 2001
From: Jinan Zhou
Date: Fri, 28 Feb 2025 11:24:17 -0800
Subject: [PATCH] fix comments

---
 .../models/timesfm/modeling_timesfm.py | 64 +++++++++----------
 1 file changed, 29 insertions(+), 35 deletions(-)

diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py
index 0219cabdd40..fa4526525f6 100644
--- a/src/transformers/models/timesfm/modeling_timesfm.py
+++ b/src/transformers/models/timesfm/modeling_timesfm.py
@@ -71,22 +71,19 @@ class TimesFmOutputForPrediction(BaseModelOutput):
     loss: Optional[Union[torch.Tensor, float]] = None
 
 
-class TimesFmMLP(nn.Module):
+class TimesFmResidualBlock(nn.Module):
     """Pax MLP in pytorch."""
 
-    def __init__(
-        self,
-        hidden_size: int,
-        intermediate_size: int,
-    ):
+    def __init__(self, config: TimesFmConfig):
         super().__init__()
+        hidden_size = config.model_dim
+        intermediate_size = config.intermediate_size
+
         self.gate_proj = nn.Linear(hidden_size, intermediate_size)
         self.down_proj = nn.Linear(intermediate_size, hidden_size)
-        self.layer_norm = nn.LayerNorm(normalized_shape=hidden_size, eps=1e-6)
 
     def forward(self, x, paddings=None):
-        gate_inp = self.layer_norm(x)
-        gate = self.gate_proj(gate_inp)
+        gate = self.gate_proj(x)
         gate = F.relu(gate)
         outputs = self.down_proj(gate)
         if paddings is not None:
@@ -94,41 +91,33 @@ class TimesFmMLP(nn.Module):
         return outputs + x
 
 
-class TimesFmResidualBlock(nn.Module):
+class TimesFmMlpBlock(nn.Module):
     """TimesFM residual block."""
 
     def __init__(self, input_dims, hidden_dims, output_dims):
         super().__init__()
-        self.input_dims = input_dims
-        self.hidden_dims = hidden_dims
-        self.output_dims = output_dims
-
-        # Hidden Layer
-        self.hidden_layer = nn.Sequential(
-            nn.Linear(input_dims, hidden_dims),
-            nn.SiLU(),
-        )
-
-        # Output Layer
+        self.linear = nn.Linear(input_dims, hidden_dims)
+        self.activation = nn.SiLU()
         self.output_layer = nn.Linear(hidden_dims, output_dims)
-        # Residual Layer
         self.residual_layer = nn.Linear(input_dims, output_dims)
 
     def forward(self, x):
-        hidden = self.hidden_layer(x)
+        hidden = self.linear(x)
+        hidden = self.activation(hidden)
        output = self.output_layer(hidden)
         residual = self.residual_layer(x)
         return output + residual
 
 
 class TimesFmRMSNorm(nn.Module):
-    def __init__(self, hidden_size, eps=1e-6):
+    def __init__(self, config: TimesFmConfig):
         """
         TimesFmRMSNorm is equivalent to T5LayerNorm
         """
         super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.variance_epsilon = eps
+
+        self.weight = nn.Parameter(torch.ones(config.model_dim))
+        self.variance_epsilon = config.rms_norm_eps
 
     def forward(self, hidden_states):
         input_dtype = hidden_states.dtype
@@ -144,7 +133,7 @@ class TimesFmRMSNorm(nn.Module):
 class TimesFmPositionalEmbedding(nn.Module):
     """Generates position embedding for a given 1-d sequence."""
 
-    def __init__(self, config: TimesFmConfig) -> None:
+    def __init__(self, config: TimesFmConfig):
         super().__init__()
         self.min_timescale = config.min_timescale
         self.max_timescale = config.max_timescale
@@ -376,8 +365,12 @@ class TimesFmDecoderLayer(nn.Module):
 
         attention_class = TIMESFM_ATTENTION_CLASSES[config._attn_implementation]
         self.self_attn = attention_class(config)
-        self.mlp = TimesFmMLP(config.model_dim, config.intermediate_size)
-        self.input_layernorm = TimesFmRMSNorm(config.model_dim, eps=config.rms_norm_eps)
+        self.residual_block = TimesFmResidualBlock(config)
+        self.rms_norm = TimesFmRMSNorm(config)
+        self.layer_norm = nn.LayerNorm(
+            normalized_shape=config.model_dim,
+            eps=config.rms_norm_eps,
+        )
 
     def forward(
         self,
@@ -390,7 +383,7 @@ class TimesFmDecoderLayer(nn.Module):
     ) -> torch.Tensor:
         # Self Attention
         residual = hidden_states
-        hidden_states = self.input_layernorm(hidden_states)
+        hidden_states = self.rms_norm(hidden_states)
         hidden_states, scores = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
@@ -401,7 +394,8 @@
         hidden_states = residual + hidden_states
 
         # MLP
-        hidden_states = self.mlp(hidden_states, paddings=paddings)
+        hidden_states = self.layer_norm(hidden_states)
+        hidden_states = self.residual_block(hidden_states, paddings=paddings)
 
         return scores, hidden_states
 
@@ -669,7 +663,7 @@ class TimesFmPreTrainedModel(PreTrainedModel):
         elif isinstance(module, TimesFmRMSNorm):
             nn.init.zeros_(module.weight)
 
-        elif isinstance(module, TimesFmMLP):
+        elif isinstance(module, TimesFmResidualBlock):
             # Initialize gate projection
             module.gate_proj.weight.data.normal_(mean=0, std=self.config.initializer_range)
             if module.gate_proj.bias is not None:
@@ -698,7 +692,7 @@ class TimesFmPreTrainedModel(PreTrainedModel):
             # Initialize scaling parameter
             nn.init.ones_(module.scaling)
 
-        elif isinstance(module, TimesFmResidualBlock):
+        elif isinstance(module, TimesFmMlpBlock):
             # Initialize hidden layer
             module.hidden_layer[0].weight.data.normal_(mean=0, std=self.config.initializer_range)
             if module.hidden_layer[0].bias is not None:
@@ -772,7 +766,7 @@ class TimesFmModel(TimesFmPreTrainedModel):
         super().__init__(config)
         self.config = config
 
-        self.input_ff_layer = TimesFmResidualBlock(
+        self.input_ff_layer = TimesFmMlpBlock(
             input_dims=2 * config.patch_len,
             output_dims=config.model_dim,
             hidden_dims=config.intermediate_size,
@@ -905,7 +899,7 @@ class TimesFmModelForPrediction(TimesFmPreTrainedModel):
         self.decoder = TimesFmModel(config)
 
         # quantile and mean output
-        self.horizon_ff_layer = TimesFmResidualBlock(
+        self.horizon_ff_layer = TimesFmMlpBlock(
             input_dims=config.model_dim,
             output_dims=config.horizon_len * (1 + len(config.quantiles)),
             hidden_dims=config.intermediate_size,
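
A behavioral subtlety in moving the nn.LayerNorm out of the MLP: the old TimesFmMLP.forward normalized only the gate input and added the un-normalized input back (`return outputs + x`), whereas after this patch TimesFmDecoderLayer.forward applies `self.layer_norm` first and then calls `self.residual_block`, so the `+ x` inside the block now adds the normalized activations. The two compositions are not equivalent in general, which matters if existing checkpoints are expected to reproduce exactly. A standalone sketch of the difference (plain PyTorch, shapes and seed arbitrary):

import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)
dim, intermediate = 8, 16
layer_norm = nn.LayerNorm(dim, eps=1e-6)
gate_proj = nn.Linear(dim, intermediate)
down_proj = nn.Linear(intermediate, dim)

def mlp_core(h):
    # gate_proj -> ReLU -> down_proj, as in both the old and new blocks
    return down_proj(F.relu(gate_proj(h)))

x = torch.randn(2, 4, dim)
old = mlp_core(layer_norm(x)) + x              # old TimesFmMLP: residual is the pre-norm input
new = mlp_core(layer_norm(x)) + layer_norm(x)  # new layer_norm + residual_block: residual is post-norm
print(torch.allclose(old, new))  # False in general: the residual paths differ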
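
By contrast, the TimesFmMlpBlock hunk is behavior-preserving: unrolling the nn.Sequential hidden layer into an explicit `linear`/`activation` pair computes the same function when the weights are shared. A self-contained check of that equivalence (dimensions arbitrary):

import torch
from torch import nn

torch.manual_seed(0)
input_dims, hidden_dims = 8, 16

linear = nn.Linear(input_dims, hidden_dims)
hidden_layer = nn.Sequential(linear, nn.SiLU())  # old: hidden = self.hidden_layer(x)
activation = nn.SiLU()                           # new: hidden = self.activation(self.linear(x))

x = torch.randn(2, 5, input_dims)
assert torch.allclose(hidden_layer(x), activation(linear(x)))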
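
One more thing worth double-checking in the `@@ -698,7 +692,7 @@` hunk: the branch is retargeted to TimesFmMlpBlock, but the unchanged context lines still initialize `module.hidden_layer[0]`, and the refactored TimesFmMlpBlock no longer has a `hidden_layer` attribute (the nn.Sequential was replaced by `linear` and `activation`), so `_init_weights` would raise AttributeError for this module. A minimal sketch of what the branch presumably needs to become; the bias-zeroing line falls outside the hunk, so its exact form is an assumption mirroring the `gate_proj` branch:

        elif isinstance(module, TimesFmMlpBlock):
            # The first projection is now `module.linear`
            # (formerly `module.hidden_layer[0]` inside the removed nn.Sequential).
            module.linear.weight.data.normal_(mean=0, std=self.config.initializer_range)
            if module.linear.bias is not None:
                module.linear.bias.data.zero_()  # assumed continuation, as in the gate_proj branch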