mirror of https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00

fix comments

commit 8deeb3e191 (parent 380e6bff37)
@@ -71,22 +71,19 @@ class TimesFmOutputForPrediction(BaseModelOutput):
     loss: Optional[Union[torch.Tensor, float]] = None


-class TimesFmMLP(nn.Module):
+class TimesFmResidualBlock(nn.Module):
     """Pax MLP in pytorch."""

-    def __init__(
-        self,
-        hidden_size: int,
-        intermediate_size: int,
-    ):
+    def __init__(self, config: TimesFmConfig):
         super().__init__()
+        hidden_size = config.model_dim
+        intermediate_size = config.intermediate_size

         self.gate_proj = nn.Linear(hidden_size, intermediate_size)
         self.down_proj = nn.Linear(intermediate_size, hidden_size)
         self.layer_norm = nn.LayerNorm(normalized_shape=hidden_size, eps=1e-6)

     def forward(self, x, paddings=None):
-        gate_inp = self.layer_norm(x)
-        gate = self.gate_proj(gate_inp)
+        gate = self.gate_proj(x)
         gate = F.relu(gate)
         outputs = self.down_proj(gate)
         if paddings is not None:
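For orientation, a minimal runnable sketch of what the renamed block computes after this change: a ReLU-gated projection pair whose update is zeroed on padded steps, plus a skip connection, with the layer norm now applied by the caller (see the decoder-layer hunks below). The class name, the toy sizes, and the exact masking line are illustrative assumptions; the hunk truncates the body under `if paddings is not None:`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GatedFeedForwardSketch(nn.Module):
    """Stand-in for the refactored TimesFmResidualBlock (illustrative, not the HF class)."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size)
        self.down_proj = nn.Linear(intermediate_size, hidden_size)

    def forward(self, x, paddings=None):
        gate = F.relu(self.gate_proj(x))   # no layer norm inside the block after the refactor
        outputs = self.down_proj(gate)
        if paddings is not None:
            # assumed masking: zero the feed-forward update on padded time steps
            outputs = outputs * (1.0 - paddings[:, :, None])
        return outputs + x                 # residual connection ("return outputs + x")


block = GatedFeedForwardSketch(hidden_size=16, intermediate_size=32)
out = block(torch.randn(2, 5, 16), paddings=torch.zeros(2, 5))
print(out.shape)  # torch.Size([2, 5, 16])
```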
@@ -94,41 +91,33 @@ class TimesFmMLP(nn.Module):
         return outputs + x


-class TimesFmResidualBlock(nn.Module):
+class TimesFmMlpBlock(nn.Module):
     """TimesFM residual block."""

     def __init__(self, input_dims, hidden_dims, output_dims):
         super().__init__()
         self.input_dims = input_dims
         self.hidden_dims = hidden_dims
         self.output_dims = output_dims

-        # Hidden Layer
-        self.hidden_layer = nn.Sequential(
-            nn.Linear(input_dims, hidden_dims),
-            nn.SiLU(),
-        )
-
-        # Output Layer
+        self.linear = nn.Linear(input_dims, hidden_dims)
+        self.activation = nn.SiLU()
         self.output_layer = nn.Linear(hidden_dims, output_dims)
-        # Residual Layer
         self.residual_layer = nn.Linear(input_dims, output_dims)

     def forward(self, x):
-        hidden = self.hidden_layer(x)
+        hidden = self.linear(x)
+        hidden = self.activation(hidden)
         output = self.output_layer(hidden)
         residual = self.residual_layer(x)
         return output + residual


 class TimesFmRMSNorm(nn.Module):
-    def __init__(self, hidden_size, eps=1e-6):
+    def __init__(self, config: TimesFmConfig):
         """
         TimesFmRMSNorm is equivalent to T5LayerNorm
         """
         super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.variance_epsilon = eps
-
+        self.weight = nn.Parameter(torch.ones(config.model_dim))
+        self.variance_epsilon = config.rms_norm_eps

     def forward(self, hidden_states):
         input_dtype = hidden_states.dtype
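Spelled out, the renamed TimesFmMlpBlock computes output_layer(SiLU(linear(x))) + residual_layer(x), i.e. a one-hidden-layer MLP with a learned linear skip path; the nn.Sequential container is simply split into separate linear and activation attributes. A quick shape check with made-up sizes:

```python
import torch
import torch.nn as nn

# Toy dimensions, chosen only to show the shapes.
input_dims, hidden_dims, output_dims = 8, 32, 4
linear = nn.Linear(input_dims, hidden_dims)          # replaces hidden_layer = nn.Sequential(...)
activation = nn.SiLU()
output_layer = nn.Linear(hidden_dims, output_dims)
residual_layer = nn.Linear(input_dims, output_dims)  # learned skip path

x = torch.randn(3, input_dims)
out = output_layer(activation(linear(x))) + residual_layer(x)
print(out.shape)  # torch.Size([3, 4])
```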
@@ -144,7 +133,7 @@ class TimesFmRMSNorm(nn.Module):
 class TimesFmPositionalEmbedding(nn.Module):
     """Generates position embedding for a given 1-d sequence."""

-    def __init__(self, config: TimesFmConfig) -> None:
+    def __init__(self, config: TimesFmConfig):
         super().__init__()
         self.min_timescale = config.min_timescale
         self.max_timescale = config.max_timescale
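This hunk only drops the `-> None` annotation. For context, min_timescale and max_timescale are the usual knobs of a sinusoidal position embedding with a geometric ladder of frequencies; the helper below is an assumption about how such fields are typically used (the classic Transformer/Pax convention), not the body of TimesFmPositionalEmbedding.

```python
import math
import torch


def sinusoidal_embedding(seq_len: int, dim: int, min_timescale: float = 1.0, max_timescale: float = 1.0e4):
    """Hypothetical sin/cos embedding driven by a min/max timescale pair."""
    position = torch.arange(seq_len, dtype=torch.float32)
    num_timescales = dim // 2
    log_increment = math.log(max_timescale / min_timescale) / max(num_timescales - 1, 1)
    inv_timescales = min_timescale * torch.exp(-torch.arange(num_timescales, dtype=torch.float32) * log_increment)
    scaled = position[:, None] * inv_timescales[None, :]              # (seq_len, dim // 2)
    return torch.cat([torch.sin(scaled), torch.cos(scaled)], dim=-1)  # (seq_len, dim)


print(sinusoidal_embedding(16, 64).shape)  # torch.Size([16, 64])
```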
@@ -376,8 +365,12 @@ class TimesFmDecoderLayer(nn.Module):
         attention_class = TIMESFM_ATTENTION_CLASSES[config._attn_implementation]

         self.self_attn = attention_class(config)
-        self.mlp = TimesFmMLP(config.model_dim, config.intermediate_size)
-        self.input_layernorm = TimesFmRMSNorm(config.model_dim, eps=config.rms_norm_eps)
+        self.residual_block = TimesFmResidualBlock(config)
+        self.rms_norm = TimesFmRMSNorm(config)
+        self.layer_norm = nn.LayerNorm(
+            normalized_shape=config.model_dim,
+            eps=config.rms_norm_eps,
+        )

     def forward(
         self,
@@ -390,7 +383,7 @@ class TimesFmDecoderLayer(nn.Module):
     ) -> torch.Tensor:
         # Self Attention
         residual = hidden_states
-        hidden_states = self.input_layernorm(hidden_states)
+        hidden_states = self.rms_norm(hidden_states)
         hidden_states, scores = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
@@ -401,7 +394,8 @@
         hidden_states = residual + hidden_states

         # MLP
-        hidden_states = self.mlp(hidden_states, paddings=paddings)
+        hidden_states = self.layer_norm(hidden_states)
+        hidden_states = self.residual_block(hidden_states, paddings=paddings)

         return scores, hidden_states

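Taken together with the constructor hunk above, the per-layer flow after this change is: RMS norm, self-attention, residual add, then LayerNorm feeding the gated block, which adds its own skip from its normalized input. A schematic, self-contained sketch; nn.MultiheadAttention and nn.LayerNorm stand in for the configurable attention class and TimesFmRMSNorm, and the padding mask line is assumed as above.

```python
import torch
import torch.nn as nn


class DecoderLayerSketch(nn.Module):
    """Schematic of the refactored TimesFmDecoderLayer.forward (stand-in submodules)."""

    def __init__(self, model_dim: int, intermediate_size: int, eps: float = 1e-6):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(model_dim, num_heads=4, batch_first=True)  # stand-in
        self.rms_norm = nn.LayerNorm(model_dim, eps=eps)    # stand-in for TimesFmRMSNorm
        self.layer_norm = nn.LayerNorm(model_dim, eps=eps)  # the newly added explicit LayerNorm
        self.gate_proj = nn.Linear(model_dim, intermediate_size)
        self.down_proj = nn.Linear(intermediate_size, model_dim)

    def forward(self, hidden_states, paddings=None):
        # Self attention: pre-norm, then residual add
        residual = hidden_states
        hidden_states = self.rms_norm(hidden_states)
        hidden_states, scores = self.self_attn(hidden_states, hidden_states, hidden_states)
        hidden_states = residual + hidden_states
        # Feed forward: LayerNorm, then the gated block with its own skip from the normed input
        normed = self.layer_norm(hidden_states)
        ff = self.down_proj(torch.relu(self.gate_proj(normed)))
        if paddings is not None:
            ff = ff * (1.0 - paddings[:, :, None])   # assumed masking of padded steps
        return scores, normed + ff


layer = DecoderLayerSketch(model_dim=16, intermediate_size=32)
scores, out = layer(torch.randn(2, 5, 16), paddings=torch.zeros(2, 5))
print(out.shape)  # torch.Size([2, 5, 16])
```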
@@ -669,7 +663,7 @@ class TimesFmPreTrainedModel(PreTrainedModel):
         elif isinstance(module, TimesFmRMSNorm):
             nn.init.zeros_(module.weight)

-        elif isinstance(module, TimesFmMLP):
+        elif isinstance(module, TimesFmResidualBlock):
             # Initialize gate projection
             module.gate_proj.weight.data.normal_(mean=0, std=self.config.initializer_range)
             if module.gate_proj.bias is not None:
@@ -698,7 +692,7 @@ class TimesFmPreTrainedModel(PreTrainedModel):
             # Initialize scaling parameter
             nn.init.ones_(module.scaling)

-        elif isinstance(module, TimesFmResidualBlock):
+        elif isinstance(module, TimesFmMlpBlock):
             # Initialize hidden layer
             module.hidden_layer[0].weight.data.normal_(mean=0, std=self.config.initializer_range)
             if module.hidden_layer[0].bias is not None:
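One thing to note around this hunk: the unchanged context still initializes module.hidden_layer[0], but the nn.Sequential container appears to be split into separate linear / activation attributes earlier in this commit, so an init branch matching the new names would presumably look something like the sketch below (an assumption, not code from this commit).

```python
import torch.nn as nn


def init_mlp_block_weights(module: nn.Module, initializer_range: float = 0.02) -> None:
    """Hypothetical re-init for the renamed attributes (linear / output_layer / residual_layer)."""
    for layer in (module.linear, module.output_layer, module.residual_layer):
        layer.weight.data.normal_(mean=0.0, std=initializer_range)
        if layer.bias is not None:
            layer.bias.data.zero_()
```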
@@ -772,7 +766,7 @@ class TimesFmModel(TimesFmPreTrainedModel):
         super().__init__(config)

         self.config = config
-        self.input_ff_layer = TimesFmResidualBlock(
+        self.input_ff_layer = TimesFmMlpBlock(
             input_dims=2 * config.patch_len,
             output_dims=config.model_dim,
             hidden_dims=config.intermediate_size,
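The factor of 2 in input_dims = 2 * config.patch_len comes from each patch of values being concatenated with a same-length padding indicator before it enters this block; the concatenation itself sits outside this hunk, so treat that as a reading of the surrounding model rather than something shown here. A shape-only sketch with toy sizes:

```python
import torch
import torch.nn as nn

patch_len, model_dim, intermediate_size = 32, 64, 128   # toy values
batch, num_patches = 4, 8

patched_values = torch.randn(batch, num_patches, patch_len)
patched_padding = torch.zeros(batch, num_patches, patch_len)
ff_input = torch.cat([patched_values, patched_padding], dim=-1)   # (4, 8, 2 * patch_len)

# Stand-in for TimesFmMlpBlock, mapping 2 * patch_len -> model_dim.
input_ff_layer = nn.Sequential(
    nn.Linear(2 * patch_len, intermediate_size),
    nn.SiLU(),
    nn.Linear(intermediate_size, model_dim),
)
print(input_ff_layer(ff_input).shape)  # torch.Size([4, 8, 64])
```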
@@ -905,7 +899,7 @@ class TimesFmModelForPrediction(TimesFmPreTrainedModel):
         self.decoder = TimesFmModel(config)

         # quantile and mean output
-        self.horizon_ff_layer = TimesFmResidualBlock(
+        self.horizon_ff_layer = TimesFmMlpBlock(
             input_dims=config.model_dim,
             output_dims=config.horizon_len * (1 + len(config.quantiles)),
             hidden_dims=config.intermediate_size,
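output_dims = config.horizon_len * (1 + len(config.quantiles)) packs, per input position, a mean forecast plus one forecast per quantile for every horizon step into a single vector, matching the "# quantile and mean output" comment. How that vector is typically unpacked, with toy sizes and the mean assumed to sit in channel 0:

```python
import torch

horizon_len, quantiles = 12, [0.1, 0.5, 0.9]   # toy values
batch, num_patches = 2, 8

head_out = torch.randn(batch, num_patches, horizon_len * (1 + len(quantiles)))
forecasts = head_out.view(batch, num_patches, horizon_len, 1 + len(quantiles))
mean_forecast = forecasts[..., 0]        # (2, 8, 12)
quantile_forecast = forecasts[..., 1:]   # (2, 8, 12, 3)
print(mean_forecast.shape, quantile_forecast.shape)
```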