fix comments
parent 380e6bff37
commit 8deeb3e191
@@ -71,22 +71,19 @@ class TimesFmOutputForPrediction(BaseModelOutput):
     loss: Optional[Union[torch.Tensor, float]] = None


-class TimesFmMLP(nn.Module):
+class TimesFmResidualBlock(nn.Module):
     """Pax MLP in pytorch."""

-    def __init__(
-        self,
-        hidden_size: int,
-        intermediate_size: int,
-    ):
+    def __init__(self, config: TimesFmConfig):
         super().__init__()
+        hidden_size = config.model_dim
+        intermediate_size = config.intermediate_size

         self.gate_proj = nn.Linear(hidden_size, intermediate_size)
         self.down_proj = nn.Linear(intermediate_size, hidden_size)
-        self.layer_norm = nn.LayerNorm(normalized_shape=hidden_size, eps=1e-6)

     def forward(self, x, paddings=None):
-        gate_inp = self.layer_norm(x)
-        gate = self.gate_proj(gate_inp)
+        gate = self.gate_proj(x)
         gate = F.relu(gate)
         outputs = self.down_proj(gate)
         if paddings is not None:
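For reference, the new side of this hunk moves the block's sizes out of the constructor signature and reads them from the config, and drops the per-block LayerNorm (it reappears in the decoder layer further down). A minimal, self-contained sketch of how the refactored block would be built and called; the config field names come from the hunk, while the stand-in config object and the padding branch (elided by the hunk) are assumptions:

```python
from types import SimpleNamespace

import torch
import torch.nn as nn
import torch.nn.functional as F


class TimesFmResidualBlock(nn.Module):
    """Gated MLP, as reconstructed from the new side of the hunk."""

    def __init__(self, config):
        super().__init__()
        hidden_size = config.model_dim
        intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(hidden_size, intermediate_size)
        self.down_proj = nn.Linear(intermediate_size, hidden_size)

    def forward(self, x, paddings=None):
        gate = self.gate_proj(x)
        gate = F.relu(gate)
        outputs = self.down_proj(gate)
        if paddings is not None:
            # The masking logic is elided by the hunk; zeroing padded positions
            # is an assumed placeholder, not the real implementation.
            outputs = outputs * (1.0 - paddings[..., None])
        return outputs + x


# Hypothetical config values, only for the sketch.
config = SimpleNamespace(model_dim=512, intermediate_size=1280)
block = TimesFmResidualBlock(config)
out = block(torch.randn(2, 16, 512))
print(out.shape)  # torch.Size([2, 16, 512])
```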
@@ -94,41 +91,33 @@ class TimesFmMLP(nn.Module):
         return outputs + x


-class TimesFmResidualBlock(nn.Module):
+class TimesFmMlpBlock(nn.Module):
     """TimesFM residual block."""

     def __init__(self, input_dims, hidden_dims, output_dims):
         super().__init__()
-        self.input_dims = input_dims
-        self.hidden_dims = hidden_dims
-        self.output_dims = output_dims
-
-        # Hidden Layer
-        self.hidden_layer = nn.Sequential(
-            nn.Linear(input_dims, hidden_dims),
-            nn.SiLU(),
-        )
-
-        # Output Layer
+        self.linear = nn.Linear(input_dims, hidden_dims)
+        self.activation = nn.SiLU()
         self.output_layer = nn.Linear(hidden_dims, output_dims)
-        # Residual Layer
         self.residual_layer = nn.Linear(input_dims, output_dims)

     def forward(self, x):
-        hidden = self.hidden_layer(x)
+        hidden = self.linear(x)
+        hidden = self.activation(hidden)
         output = self.output_layer(hidden)
         residual = self.residual_layer(x)
         return output + residual


 class TimesFmRMSNorm(nn.Module):
-    def __init__(self, hidden_size, eps=1e-6):
+    def __init__(self, config: TimesFmConfig):
         """
         TimesFmRMSNorm is equivalent to T5LayerNorm
         """
         super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.variance_epsilon = eps
+        self.weight = nn.Parameter(torch.ones(config.model_dim))
+        self.variance_epsilon = config.rms_norm_eps

     def forward(self, hidden_states):
         input_dtype = hidden_states.dtype
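Similarly, the renamed `TimesFmMlpBlock` keeps the `input_dims`/`hidden_dims`/`output_dims` signature but replaces the `nn.Sequential` hidden layer with an explicit linear plus SiLU pair, and `TimesFmRMSNorm` now reads its size and epsilon from the config. A sketch assembled from the new side of the hunk; the RMSNorm forward body is outside this hunk, so a standard T5-style RMS computation is assumed:

```python
import torch
import torch.nn as nn


class TimesFmMlpBlock(nn.Module):
    """Residual MLP block, as on the new side of the hunk."""

    def __init__(self, input_dims, hidden_dims, output_dims):
        super().__init__()
        self.linear = nn.Linear(input_dims, hidden_dims)
        self.activation = nn.SiLU()
        self.output_layer = nn.Linear(hidden_dims, output_dims)
        self.residual_layer = nn.Linear(input_dims, output_dims)

    def forward(self, x):
        hidden = self.linear(x)
        hidden = self.activation(hidden)
        output = self.output_layer(hidden)
        residual = self.residual_layer(x)
        return output + residual


class TimesFmRMSNorm(nn.Module):
    """Config-driven RMSNorm; the forward is assumed (standard T5-style RMS)."""

    def __init__(self, config):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(config.model_dim))
        self.variance_epsilon = config.rms_norm_eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states.to(torch.float32) * torch.rsqrt(variance + self.variance_epsilon)
        return (self.weight * hidden_states).to(input_dtype)


# Hypothetical sizes, only to show that the residual shapes line up.
mlp = TimesFmMlpBlock(input_dims=64, hidden_dims=128, output_dims=32)
print(mlp(torch.randn(4, 64)).shape)  # torch.Size([4, 32])
```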
@@ -144,7 +133,7 @@ class TimesFmRMSNorm(nn.Module):
 class TimesFmPositionalEmbedding(nn.Module):
     """Generates position embedding for a given 1-d sequence."""

-    def __init__(self, config: TimesFmConfig) -> None:
+    def __init__(self, config: TimesFmConfig):
         super().__init__()
         self.min_timescale = config.min_timescale
         self.max_timescale = config.max_timescale
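The change here is only the dropped `-> None` annotation, but the two config fields shown are the timescale bounds of a sinusoidal positional embedding. A sketch of that computation, assuming the usual transformer/Pax formulation (the real forward is outside this hunk, and the function name below is only illustrative):

```python
import math

import torch


def sinusoidal_position_embedding(seq_len, model_dim, min_timescale=1.0, max_timescale=10_000.0):
    """Assumed standard sinusoidal embedding parameterized by min/max timescales."""
    num_timescales = model_dim // 2
    log_increment = math.log(max_timescale / min_timescale) / max(num_timescales - 1, 1)
    inv_timescales = min_timescale * torch.exp(-torch.arange(num_timescales) * log_increment)
    positions = torch.arange(seq_len, dtype=torch.float32)
    scaled = positions[:, None] * inv_timescales[None, :]
    # Concatenate sine and cosine channels to get one model_dim vector per position.
    return torch.cat([torch.sin(scaled), torch.cos(scaled)], dim=-1)


print(sinusoidal_position_embedding(16, 512).shape)  # torch.Size([16, 512])
```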
@@ -376,8 +365,12 @@ class TimesFmDecoderLayer(nn.Module):
         attention_class = TIMESFM_ATTENTION_CLASSES[config._attn_implementation]

         self.self_attn = attention_class(config)
-        self.mlp = TimesFmMLP(config.model_dim, config.intermediate_size)
-        self.input_layernorm = TimesFmRMSNorm(config.model_dim, eps=config.rms_norm_eps)
+        self.residual_block = TimesFmResidualBlock(config)
+        self.rms_norm = TimesFmRMSNorm(config)
+        self.layer_norm = nn.LayerNorm(
+            normalized_shape=config.model_dim,
+            eps=config.rms_norm_eps,
+        )

     def forward(
         self,
@@ -390,7 +383,7 @@ class TimesFmDecoderLayer(nn.Module):
     ) -> torch.Tensor:
         # Self Attention
         residual = hidden_states
-        hidden_states = self.input_layernorm(hidden_states)
+        hidden_states = self.rms_norm(hidden_states)
         hidden_states, scores = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
@@ -401,7 +394,8 @@ class TimesFmDecoderLayer(nn.Module):
         hidden_states = residual + hidden_states

         # MLP
-        hidden_states = self.mlp(hidden_states, paddings=paddings)
+        hidden_states = self.layer_norm(hidden_states)
+        hidden_states = self.residual_block(hidden_states, paddings=paddings)

         return scores, hidden_states

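Taken together, the three decoder-layer hunks give the layer a pre-attention RMSNorm and an explicit LayerNorm in front of the gated MLP, whose skip connection stays inside the block. A sketch of the resulting wiring, reusing the `TimesFmResidualBlock` and `TimesFmRMSNorm` sketches above and a stub in place of the real attention class:

```python
from types import SimpleNamespace

import torch
import torch.nn as nn


class DecoderLayerSketch(nn.Module):
    """Wiring of the refactored TimesFmDecoderLayer (attention stubbed out)."""

    def __init__(self, config):
        super().__init__()
        # Stand-in for TIMESFM_ATTENTION_CLASSES[config._attn_implementation].
        self.self_attn = lambda hidden_states, attention_mask=None: (hidden_states, None)
        self.residual_block = TimesFmResidualBlock(config)
        self.rms_norm = TimesFmRMSNorm(config)
        self.layer_norm = nn.LayerNorm(normalized_shape=config.model_dim, eps=config.rms_norm_eps)

    def forward(self, hidden_states, attention_mask=None, paddings=None):
        # Self attention: RMSNorm pre-norm, then residual add.
        residual = hidden_states
        hidden_states = self.rms_norm(hidden_states)
        hidden_states, scores = self.self_attn(hidden_states, attention_mask=attention_mask)
        hidden_states = residual + hidden_states

        # MLP path: LayerNorm first, then the gated residual block
        # (which adds its own skip connection internally).
        hidden_states = self.layer_norm(hidden_states)
        hidden_states = self.residual_block(hidden_states, paddings=paddings)
        return scores, hidden_states


# Hypothetical config values, only for the sketch.
config = SimpleNamespace(model_dim=512, intermediate_size=1280, rms_norm_eps=1e-6)
layer = DecoderLayerSketch(config)
scores, out = layer(torch.randn(2, 16, 512))
print(out.shape)  # torch.Size([2, 16, 512])
```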
@@ -669,7 +663,7 @@ class TimesFmPreTrainedModel(PreTrainedModel):
         elif isinstance(module, TimesFmRMSNorm):
             nn.init.zeros_(module.weight)

-        elif isinstance(module, TimesFmMLP):
+        elif isinstance(module, TimesFmResidualBlock):
             # Initialize gate projection
             module.gate_proj.weight.data.normal_(mean=0, std=self.config.initializer_range)
             if module.gate_proj.bias is not None:
@@ -698,7 +692,7 @@ class TimesFmPreTrainedModel(PreTrainedModel):
             # Initialize scaling parameter
             nn.init.ones_(module.scaling)

-        elif isinstance(module, TimesFmResidualBlock):
+        elif isinstance(module, TimesFmMlpBlock):
             # Initialize hidden layer
             module.hidden_layer[0].weight.data.normal_(mean=0, std=self.config.initializer_range)
             if module.hidden_layer[0].bias is not None:
@@ -772,7 +766,7 @@ class TimesFmModel(TimesFmPreTrainedModel):
         super().__init__(config)

         self.config = config
-        self.input_ff_layer = TimesFmResidualBlock(
+        self.input_ff_layer = TimesFmMlpBlock(
             input_dims=2 * config.patch_len,
             output_dims=config.model_dim,
             hidden_dims=config.intermediate_size,
@@ -905,7 +899,7 @@ class TimesFmModelForPrediction(TimesFmPreTrainedModel):
         self.decoder = TimesFmModel(config)

         # quantile and mean output
-        self.horizon_ff_layer = TimesFmResidualBlock(
+        self.horizon_ff_layer = TimesFmMlpBlock(
             input_dims=config.model_dim,
             output_dims=config.horizon_len * (1 + len(config.quantiles)),
             hidden_dims=config.intermediate_size,
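Finally, the renamed `TimesFmMlpBlock` serves both as the patch-embedding layer (`input_ff_layer`) and as the output head (`horizon_ff_layer`). A shape sketch reusing the `TimesFmMlpBlock` above; the concrete numbers are hypothetical, and the `2 * patch_len` input is presumably the patched values concatenated with a per-position padding indicator:

```python
import torch

# Hypothetical config values, chosen only to make the shapes concrete.
patch_len, model_dim, intermediate_size = 32, 512, 1280
horizon_len, quantiles = 128, [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]

# Patch embedding: 2 * patch_len inputs per patch, projected to model_dim.
input_ff_layer = TimesFmMlpBlock(
    input_dims=2 * patch_len,
    hidden_dims=intermediate_size,
    output_dims=model_dim,
)

# Output head: one mean plus one value per quantile for every horizon step.
horizon_ff_layer = TimesFmMlpBlock(
    input_dims=model_dim,
    hidden_dims=intermediate_size,
    output_dims=horizon_len * (1 + len(quantiles)),
)

patches = torch.randn(2, 10, 2 * patch_len)   # (batch, num_patches, 2 * patch_len)
embedded = input_ff_layer(patches)            # (2, 10, 512)
head_out = horizon_ff_layer(embedded[:, -1])  # (2, 128 * 10)
print(embedded.shape, head_out.shape)
```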