fix comments

This commit is contained in:
Jinan Zhou 2025-02-28 11:24:17 -08:00
parent 380e6bff37
commit 8deeb3e191

View File

@ -71,22 +71,19 @@ class TimesFmOutputForPrediction(BaseModelOutput):
loss: Optional[Union[torch.Tensor, float]] = None
class TimesFmMLP(nn.Module):
class TimesFmResidualBlock(nn.Module):
"""Pax MLP in pytorch."""
def __init__(
self,
hidden_size: int,
intermediate_size: int,
):
def __init__(self, config: TimesFmConfig):
super().__init__()
hidden_size = config.model_dim
intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(hidden_size, intermediate_size)
self.down_proj = nn.Linear(intermediate_size, hidden_size)
self.layer_norm = nn.LayerNorm(normalized_shape=hidden_size, eps=1e-6)
def forward(self, x, paddings=None):
gate_inp = self.layer_norm(x)
gate = self.gate_proj(gate_inp)
gate = self.gate_proj(x)
gate = F.relu(gate)
outputs = self.down_proj(gate)
if paddings is not None:
@ -94,41 +91,33 @@ class TimesFmMLP(nn.Module):
return outputs + x
class TimesFmResidualBlock(nn.Module):
class TimesFmMlpBlock(nn.Module):
"""TimesFM residual block."""
def __init__(self, input_dims, hidden_dims, output_dims):
super().__init__()
self.input_dims = input_dims
self.hidden_dims = hidden_dims
self.output_dims = output_dims
# Hidden Layer
self.hidden_layer = nn.Sequential(
nn.Linear(input_dims, hidden_dims),
nn.SiLU(),
)
# Output Layer
self.linear = nn.Linear(input_dims, hidden_dims)
self.activation = nn.SiLU()
self.output_layer = nn.Linear(hidden_dims, output_dims)
# Residual Layer
self.residual_layer = nn.Linear(input_dims, output_dims)
def forward(self, x):
hidden = self.hidden_layer(x)
hidden = self.linear(x)
hidden = self.activation(hidden)
output = self.output_layer(hidden)
residual = self.residual_layer(x)
return output + residual
class TimesFmRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
def __init__(self, config: TimesFmConfig):
"""
TimesFmRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
self.weight = nn.Parameter(torch.ones(config.model_dim))
self.variance_epsilon = config.rms_norm_eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
@ -144,7 +133,7 @@ class TimesFmRMSNorm(nn.Module):
class TimesFmPositionalEmbedding(nn.Module):
"""Generates position embedding for a given 1-d sequence."""
def __init__(self, config: TimesFmConfig) -> None:
def __init__(self, config: TimesFmConfig):
super().__init__()
self.min_timescale = config.min_timescale
self.max_timescale = config.max_timescale
@ -376,8 +365,12 @@ class TimesFmDecoderLayer(nn.Module):
attention_class = TIMESFM_ATTENTION_CLASSES[config._attn_implementation]
self.self_attn = attention_class(config)
self.mlp = TimesFmMLP(config.model_dim, config.intermediate_size)
self.input_layernorm = TimesFmRMSNorm(config.model_dim, eps=config.rms_norm_eps)
self.residual_block = TimesFmResidualBlock(config)
self.rms_norm = TimesFmRMSNorm(config)
self.layer_norm = nn.LayerNorm(
normalized_shape=config.model_dim,
eps=config.rms_norm_eps,
)
def forward(
self,
@ -390,7 +383,7 @@ class TimesFmDecoderLayer(nn.Module):
) -> torch.Tensor:
# Self Attention
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.rms_norm(hidden_states)
hidden_states, scores = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
@ -401,7 +394,8 @@ class TimesFmDecoderLayer(nn.Module):
hidden_states = residual + hidden_states
# MLP
hidden_states = self.mlp(hidden_states, paddings=paddings)
hidden_states = self.layer_norm(hidden_states)
hidden_states = self.residual_block(hidden_states, paddings=paddings)
return scores, hidden_states
@ -669,7 +663,7 @@ class TimesFmPreTrainedModel(PreTrainedModel):
elif isinstance(module, TimesFmRMSNorm):
nn.init.zeros_(module.weight)
elif isinstance(module, TimesFmMLP):
elif isinstance(module, TimesFmResidualBlock):
# Initialize gate projection
module.gate_proj.weight.data.normal_(mean=0, std=self.config.initializer_range)
if module.gate_proj.bias is not None:
@ -698,7 +692,7 @@ class TimesFmPreTrainedModel(PreTrainedModel):
# Initialize scaling parameter
nn.init.ones_(module.scaling)
elif isinstance(module, TimesFmResidualBlock):
elif isinstance(module, TimesFmMlpBlock):
# Initialize hidden layer
module.hidden_layer[0].weight.data.normal_(mean=0, std=self.config.initializer_range)
if module.hidden_layer[0].bias is not None:
@ -772,7 +766,7 @@ class TimesFmModel(TimesFmPreTrainedModel):
super().__init__(config)
self.config = config
self.input_ff_layer = TimesFmResidualBlock(
self.input_ff_layer = TimesFmMlpBlock(
input_dims=2 * config.patch_len,
output_dims=config.model_dim,
hidden_dims=config.intermediate_size,
@ -905,7 +899,7 @@ class TimesFmModelForPrediction(TimesFmPreTrainedModel):
self.decoder = TimesFmModel(config)
# quantile and mean output
self.horizon_ff_layer = TimesFmResidualBlock(
self.horizon_ff_layer = TimesFmMlpBlock(
input_dims=config.model_dim,
output_dims=config.horizon_len * (1 + len(config.quantiles)),
hidden_dims=config.intermediate_size,