fix comments

Jinan Zhou 2025-02-28 11:24:17 -08:00
parent 380e6bff37
commit 8deeb3e191


@@ -71,22 +71,19 @@ class TimesFmOutputForPrediction(BaseModelOutput):
     loss: Optional[Union[torch.Tensor, float]] = None
 
 
-class TimesFmMLP(nn.Module):
+class TimesFmResidualBlock(nn.Module):
     """Pax MLP in pytorch."""
 
-    def __init__(
-        self,
-        hidden_size: int,
-        intermediate_size: int,
-    ):
+    def __init__(self, config: TimesFmConfig):
        super().__init__()
+        hidden_size = config.model_dim
+        intermediate_size = config.intermediate_size
         self.gate_proj = nn.Linear(hidden_size, intermediate_size)
         self.down_proj = nn.Linear(intermediate_size, hidden_size)
-        self.layer_norm = nn.LayerNorm(normalized_shape=hidden_size, eps=1e-6)
 
     def forward(self, x, paddings=None):
-        gate_inp = self.layer_norm(x)
-        gate = self.gate_proj(gate_inp)
+        gate = self.gate_proj(x)
         gate = F.relu(gate)
         outputs = self.down_proj(gate)
         if paddings is not None:
@@ -94,41 +91,33 @@ class TimesFmMLP(nn.Module):
         return outputs + x
 
 
-class TimesFmResidualBlock(nn.Module):
+class TimesFmMlpBlock(nn.Module):
     """TimesFM residual block."""
 
     def __init__(self, input_dims, hidden_dims, output_dims):
         super().__init__()
-        self.input_dims = input_dims
-        self.hidden_dims = hidden_dims
-        self.output_dims = output_dims
-
-        # Hidden Layer
-        self.hidden_layer = nn.Sequential(
-            nn.Linear(input_dims, hidden_dims),
-            nn.SiLU(),
-        )
-        # Output Layer
+        self.linear = nn.Linear(input_dims, hidden_dims)
+        self.activation = nn.SiLU()
         self.output_layer = nn.Linear(hidden_dims, output_dims)
-        # Residual Layer
         self.residual_layer = nn.Linear(input_dims, output_dims)
 
     def forward(self, x):
-        hidden = self.hidden_layer(x)
+        hidden = self.linear(x)
+        hidden = self.activation(hidden)
         output = self.output_layer(hidden)
         residual = self.residual_layer(x)
         return output + residual
 
 
 class TimesFmRMSNorm(nn.Module):
-    def __init__(self, hidden_size, eps=1e-6):
+    def __init__(self, config: TimesFmConfig):
         """
         TimesFmRMSNorm is equivalent to T5LayerNorm
         """
         super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.variance_epsilon = eps
+        self.weight = nn.Parameter(torch.ones(config.model_dim))
+        self.variance_epsilon = config.rms_norm_eps
 
     def forward(self, hidden_states):
         input_dtype = hidden_states.dtype
@@ -144,7 +133,7 @@ class TimesFmRMSNorm(nn.Module):
 class TimesFmPositionalEmbedding(nn.Module):
     """Generates position embedding for a given 1-d sequence."""
 
-    def __init__(self, config: TimesFmConfig) -> None:
+    def __init__(self, config: TimesFmConfig):
         super().__init__()
         self.min_timescale = config.min_timescale
         self.max_timescale = config.max_timescale
@@ -376,8 +365,12 @@ class TimesFmDecoderLayer(nn.Module):
         attention_class = TIMESFM_ATTENTION_CLASSES[config._attn_implementation]
 
         self.self_attn = attention_class(config)
-        self.mlp = TimesFmMLP(config.model_dim, config.intermediate_size)
-        self.input_layernorm = TimesFmRMSNorm(config.model_dim, eps=config.rms_norm_eps)
+        self.residual_block = TimesFmResidualBlock(config)
+        self.rms_norm = TimesFmRMSNorm(config)
+        self.layer_norm = nn.LayerNorm(
+            normalized_shape=config.model_dim,
+            eps=config.rms_norm_eps,
+        )
 
     def forward(
         self,
@@ -390,7 +383,7 @@ class TimesFmDecoderLayer(nn.Module):
     ) -> torch.Tensor:
         # Self Attention
         residual = hidden_states
-        hidden_states = self.input_layernorm(hidden_states)
+        hidden_states = self.rms_norm(hidden_states)
         hidden_states, scores = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
@@ -401,7 +394,8 @@ class TimesFmDecoderLayer(nn.Module):
         hidden_states = residual + hidden_states
 
         # MLP
-        hidden_states = self.mlp(hidden_states, paddings=paddings)
+        hidden_states = self.layer_norm(hidden_states)
+        hidden_states = self.residual_block(hidden_states, paddings=paddings)
 
         return scores, hidden_states
@@ -669,7 +663,7 @@ class TimesFmPreTrainedModel(PreTrainedModel):
         elif isinstance(module, TimesFmRMSNorm):
             nn.init.zeros_(module.weight)
-        elif isinstance(module, TimesFmMLP):
+        elif isinstance(module, TimesFmResidualBlock):
             # Initialize gate projection
             module.gate_proj.weight.data.normal_(mean=0, std=self.config.initializer_range)
             if module.gate_proj.bias is not None:
@@ -698,7 +692,7 @@ class TimesFmPreTrainedModel(PreTrainedModel):
             # Initialize scaling parameter
             nn.init.ones_(module.scaling)
-        elif isinstance(module, TimesFmResidualBlock):
+        elif isinstance(module, TimesFmMlpBlock):
             # Initialize hidden layer
             module.hidden_layer[0].weight.data.normal_(mean=0, std=self.config.initializer_range)
             if module.hidden_layer[0].bias is not None:
@@ -772,7 +766,7 @@ class TimesFmModel(TimesFmPreTrainedModel):
         super().__init__(config)
         self.config = config
-        self.input_ff_layer = TimesFmResidualBlock(
+        self.input_ff_layer = TimesFmMlpBlock(
             input_dims=2 * config.patch_len,
             output_dims=config.model_dim,
             hidden_dims=config.intermediate_size,
@@ -905,7 +899,7 @@ class TimesFmModelForPrediction(TimesFmPreTrainedModel):
         self.decoder = TimesFmModel(config)
 
         # quantile and mean output
-        self.horizon_ff_layer = TimesFmResidualBlock(
+        self.horizon_ff_layer = TimesFmMlpBlock(
             input_dims=config.model_dim,
             output_dims=config.horizon_len * (1 + len(config.quantiles)),
             hidden_dims=config.intermediate_size,
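
For reference, a minimal usage sketch of the blocks as renamed by this commit. The constructor signatures are taken from the hunks above; the assumption that TimesFmConfig() can be built with defaults supplying model_dim, intermediate_size, patch_len, and rms_norm_eps is for illustration only.

# Hypothetical sketch; assumes TimesFmConfig() defaults provide the fields used below.
config = TimesFmConfig()

# Formerly TimesFmMLP(hidden_size, intermediate_size); now reads its dims from the config.
gated_mlp = TimesFmResidualBlock(config)

# Formerly TimesFmRMSNorm(hidden_size, eps=...); now reads model_dim and rms_norm_eps from the config.
rms_norm = TimesFmRMSNorm(config)

# Formerly named TimesFmResidualBlock; keeps its explicit dimension arguments.
input_ff_layer = TimesFmMlpBlock(
    input_dims=2 * config.patch_len,
    hidden_dims=config.intermediate_size,
    output_dims=config.model_dim,
)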