Add TimesFM Time Series Forecasting Model (#34082)
* initial documentation * rename mask to attention_mask * smaller tests * fixup * fix copies * move to time series section * sort docs * isort fix * batch_size is not a configuration * rename to TimesFMModelForPrediction * initial script * add check_outputs * remove dropout_rate * works with torch.Tensor inputs * rename script * fix docstrings * fix freq when window_size is given * add loss * fix _quantile_loss * formatting * fix isort * add weight init * add support for sdpa and flash_attention_2 * fixes for flash_attention * formatting * remove flash_attention * fix tests * fix file name * fix quantile loss * added initial TimesFMModelIntegrationTests * fix formatting * fix import order * fix _quantile_loss * add doc for SDPA * use timesfm 2.0 * bug fix in timesfm decode function. * compare mean forecasts * refactor type hints, use CamelCase * consolidate decode func * more readable code for weight conversion * fix-copies * simpler init * renaem TimesFmMLP * use T5LayerNorm * fix tests * use initializer_range * TimesFmModel instead of TimesFmDecoder * TimesFmPositionalEmbedding takes config for its init * 2.0-500m-pytorch default configs * use TimesFmModel * fix formatting * ignore TimesFmModel for testing * fix docstring * override generate as its not needed * add doc strings * fix logging * add docstrings to output data classes * initial copy from t5 * added config and attention layers * add TimesFMPositionalEmbedding * calcuate scale_factor once * add more configs and TimesFMResidualBlock * fix input_dims * standardize code format with black * remove unneeded modules * TimesFM Model * order of imports * copy from Google official implementation * remove covariate forecasting * Adapting TimesFM to HF format * restructing in progress * adapted to HF convention * timesfm test * the model runs * fixing unit tests * fixing unit tests in progress * add post_init * do not change TimesFMOutput * fixing unit tests * all unit tests passed * remove timesfm_layers * add intermediate_size and initialize with config * add _CHECKPOINT_FOR_DOC * fix comments * Revert "fix comments" This reverts commit 8deeb3e191. * add _prepare_4d_attention_mask * we do not have generative model classes * use Cache * return past_key_values * modules initialized with config only * update year * Update docs/source/en/model_doc/timesfm.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * add layer_idx to cache * modular timesfm * fix test * unwrap sequential class * fix toctree * remove TimesFmOnnxConfig * fix modular * remove TimesFmStackedDecoder * split qkv layer into individual layers * rename projection layers * use ALL_ATTENTION_FUNCTIONS * is_causal is True * rename config * does not support flash_attn_2 * formatting * fix typo in docsstring * rename inputs * add time series mapping * Update src/transformers/models/olmo2/modeling_olmo2.py * Update src/transformers/models/moonshine/modeling_moonshine.py * use updated arguments * fix class name * add MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING * isort * consolidate _preprocess into forward * fix a typo * fix a typo * fix toc * fix modular * remove aaserts * use self.config._attn_implementation * move to _postprocess_output * remove timesfm_get_large_negative_number * use view unstead of multiple unsqueeze * make helpers static methods of the Model * use to_tuple * use to_tuple if not return_dict * remove unused intitialization block as its incorporated in nn.Linear * remove unused num_key_value_groups * use the same convention as the masking method * update modular * do not use unsqueeze * use view instead of unsqueeze * use buffer for inv_timescales * formatting * modular conversion * remove unneeded intialization * add missing docstrings * remove cache * use simple_eager_attention_forward * support tp_plan * support for flex and flash attention masks * Revert "support for flex and flash attention masks" This reverts commit def36c4fcf. * fix device * fix tests on gpu * remove unsued large model test * removed unneeded comments * add example usage * fix style * add import * Update docs/source/en/model_doc/timesfm.md Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com> * inherit from LlamaRMSNorm * use can_return_tuple decorator * remvoe return_dict * fix year * Update docs/source/en/model_doc/timesfm.md Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com> * pretrained does not inherit from GenerationMixin * use model for integration test --------- Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com> Co-authored-by: Rajat Sen <rsen91@gmail.com> Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com> Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co>
This commit is contained in:
parent
8669c016d2
commit
a91020aed0
@@ -1059,6 +1059,8 @@
      title: PatchTST
    - local: model_doc/time_series_transformer
      title: Time Series Transformer
    - local: model_doc/timesfm
      title: TimesFM
    title: Time series models
  - sections:
    - local: model_doc/graphormer

88 docs/source/en/model_doc/timesfm.md Normal file
@@ -0,0 +1,88 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# TimesFM

<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
</div>

## Overview

TimesFM (Time Series Foundation Model) is a pretrained time-series foundation model proposed in [A decoder-only foundation model for time-series forecasting](https://huggingface.co/papers/2310.10688) by Abhimanyu Das, Weihao Kong, Rajat Sen, and Yichen Zhou. It is a decoder-only model that takes non-overlapping patches of time-series data as input and predicts output patches autoregressively.

The abstract from the paper is the following:

*Motivated by recent advances in large language models for Natural Language Processing (NLP), we design a time-series foundation model for forecasting whose out-of-the-box zero-shot performance on a variety of public datasets comes close to the accuracy of state-of-the-art supervised forecasting models for each individual dataset. Our model is based on pretraining a patched-decoder style attention model on a large time-series corpus, and can work well across different forecasting history lengths, prediction lengths and temporal granularities.*

This model was contributed by [kashif](https://huggingface.co/kashif).
The original code can be found [here](https://github.com/google-research/timesfm).

To use the model:

```python
import numpy as np
import torch

from transformers import TimesFmModelForPrediction


model = TimesFmModelForPrediction.from_pretrained(
    "google/timesfm-2.0-500m-pytorch",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
    device_map="cuda" if torch.cuda.is_available() else None,
)

# Create dummy inputs
forecast_input = [
    np.sin(np.linspace(0, 20, 100)),
    np.sin(np.linspace(0, 20, 200)),
    np.sin(np.linspace(0, 20, 400)),
]
frequency_input = [0, 1, 2]

# Convert inputs to sequence of tensors
forecast_input_tensor = [
    torch.tensor(ts, dtype=torch.bfloat16).to("cuda" if torch.cuda.is_available() else "cpu")
    for ts in forecast_input
]
frequency_input_tensor = torch.tensor(frequency_input, dtype=torch.long).to(
    "cuda" if torch.cuda.is_available() else "cpu"
)

# Get predictions from the pre-trained model
with torch.no_grad():
    outputs = model(past_values=forecast_input_tensor, freq=frequency_input_tensor, return_dict=True)
    point_forecast_conv = outputs.mean_predictions.float().cpu().numpy()
    quantile_forecast_conv = outputs.full_predictions.float().cpu().numpy()
```
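
The output dataclass exposes both a point forecast (`mean_predictions`) and the quantile heads (`full_predictions`). As a minimal, hedged sketch continuing the example above (the exact channel layout, mean first followed by the entries of `config.quantiles`, and the shapes shown in the comments are assumptions to verify against the checkpoint you load):

```python
# Continuing from the example above.
mean = outputs.mean_predictions.float().cpu()  # assumed shape: (batch, horizon_length)
full = outputs.full_predictions.float().cpu()  # assumed shape: (batch, horizon_length, 1 + len(quantiles))

quantiles = model.config.quantiles             # default: [0.1, 0.2, ..., 0.9]
p10 = full[..., 1 + quantiles.index(0.1)]      # assumed lower prediction band
p90 = full[..., 1 + quantiles.index(0.9)]      # assumed upper prediction band
print(mean.shape, p10.shape, p90.shape)
```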

## TimesFmConfig

[[autodoc]] TimesFmConfig

## TimesFmModel

[[autodoc]] TimesFmModel
    - forward

## TimesFmModelForPrediction

[[autodoc]] TimesFmModelForPrediction
    - forward
@@ -279,6 +279,7 @@ if TYPE_CHECKING:
    from .tapas import *
    from .textnet import *
    from .time_series_transformer import *
    from .timesfm import *
    from .timesformer import *
    from .timm_backbone import *
    from .timm_wrapper import *

@@ -313,6 +313,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
        ("tapas", "TapasConfig"),
        ("textnet", "TextNetConfig"),
        ("time_series_transformer", "TimeSeriesTransformerConfig"),
        ("timesfm", "TimesFmConfig"),
        ("timesformer", "TimesformerConfig"),
        ("timm_backbone", "TimmBackboneConfig"),
        ("timm_wrapper", "TimmWrapperConfig"),

@@ -681,6 +682,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
        ("tapex", "TAPEX"),
        ("textnet", "TextNet"),
        ("time_series_transformer", "Time Series Transformer"),
        ("timesfm", "TimesFm"),
        ("timesformer", "TimeSformer"),
        ("timm_backbone", "TimmBackbone"),
        ("timm_wrapper", "TimmWrapperModel"),

@@ -281,6 +281,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
        ("tapas", "TapasModel"),
        ("textnet", "TextNetModel"),
        ("time_series_transformer", "TimeSeriesTransformerModel"),
        ("timesfm", "TimesFmModel"),
        ("timesformer", "TimesformerModel"),
        ("timm_backbone", "TimmBackbone"),
        ("timm_wrapper", "TimmWrapperModel"),

@@ -1542,6 +1543,12 @@ MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING_NAMES = OrderedDict(
    ]
)

MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING_NAMES = OrderedDict(
    [
        ("timesfm", "TimesFmModelForPrediction"),
    ]
)

MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES = OrderedDict(
    [
        ("swin2sr", "Swin2SRForImageSuperResolution"),

@@ -1650,6 +1657,10 @@ MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING_NAMES
)

MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING_NAMES
)

MODEL_FOR_IMAGE_TO_IMAGE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES)

@@ -1820,6 +1831,15 @@ AutoModelForSemanticSegmentation = auto_class_update(
)


class AutoModelForTimeSeriesPrediction(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING


AutoModelForTimeSeriesPrediction = auto_class_update(
    AutoModelForTimeSeriesPrediction, head_doc="time-series prediction"
)


class AutoModelForUniversalSegmentation(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING

@@ -1994,6 +2014,7 @@ __all__ = [
    "MODEL_FOR_TEXT_ENCODING_MAPPING",
    "MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING",
    "MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING",
    "MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING",
    "MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
    "MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING",
    "MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING",
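The hunks above register TimesFM with the auto classes and add a new `AutoModelForTimeSeriesPrediction` head class. As an illustrative, hedged sketch of how that mapping resolves (assuming the class is re-exported from the top-level `transformers` namespace, which this diff does not show):

```python
# Assumption: AutoModelForTimeSeriesPrediction is importable from the top-level package.
from transformers import AutoConfig, AutoModelForTimeSeriesPrediction

config = AutoConfig.from_pretrained("google/timesfm-2.0-500m-pytorch")  # resolves to TimesFmConfig
model = AutoModelForTimeSeriesPrediction.from_pretrained("google/timesfm-2.0-500m-pytorch")
print(type(model).__name__)  # expected: TimesFmModelForPrediction, via the new mapping
```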
27 src/transformers/models/timesfm/__init__.py Normal file
@@ -0,0 +1,27 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_timesfm import *
    from .modeling_timesfm import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)

129 src/transformers/models/timesfm/configuration_timesfm.py Normal file
@@ -0,0 +1,129 @@
# coding=utf-8
# Copyright 2025 Google LLC and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TimesFM model configuration"""

from typing import List

from ...configuration_utils import PretrainedConfig
from ...utils import logging


logger = logging.get_logger(__name__)


class TimesFmConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`TimesFmModel`] or a [`TimesFmModelForPrediction`].
    It is used to instantiate a TimesFM model according to the specified arguments, defining the model architecture.
    Instantiating a configuration with the defaults will yield a similar configuration to that of the TimesFM
    [google/timesfm-2.0-500m-pytorch](https://huggingface.co/google/timesfm-2.0-500m-pytorch) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Arguments:
        patch_length (`int`, *optional*, defaults to 32):
            The length of one patch in the input sequence.
        context_length (`int`, *optional*, defaults to 512):
            The length of the input context.
        horizon_length (`int`, *optional*, defaults to 128):
            The length of the prediction horizon.
        freq_size (`int`, *optional*, defaults to 3):
            The number of frequency embeddings.
        num_hidden_layers (`int`, *optional*, defaults to 50):
            Number of Transformer layers.
        hidden_size (`int`, *optional*, defaults to 1280):
            Size of the hidden layers in the feed-forward networks.
        intermediate_size (`int`, *optional*, defaults to 1280):
            Dimension of the MLP representations.
        head_dim (`int`, *optional*, defaults to 80):
            Size of the key, query, value projections per attention head. The `inner_dim` of the projection layer
            will be defined as `num_attention_heads * head_dim`.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer encoder.
        tolerance (`float`, *optional*, defaults to 1e-06):
            The tolerance for the quantile loss.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the RMS normalization layers.
        quantiles (`List[float]`, *optional*, defaults to `[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]`):
            The quantiles to predict.
        pad_val (`float`, *optional*, defaults to 1123581321.0):
            The value used to pad the predictions.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout probability for the attention scores.
        use_positional_embedding (`bool`, *optional*, defaults to `False`):
            Whether to add positional embeddings.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        min_timescale (`int`, *optional*, defaults to 1):
            The start of the geometric positional index. Determines the periodicity of the added signal.
        max_timescale (`int`, *optional*, defaults to 10000):
            The end of the geometric positional index. Determines the frequency of the added signal.
    """

    model_type = "timesfm"
    keys_to_ignore_at_inference = []
    is_encoder_decoder = False

    def __init__(
        self,
        patch_length: int = 32,
        context_length: int = 512,
        horizon_length: int = 128,
        freq_size: int = 3,
        num_hidden_layers: int = 50,
        hidden_size: int = 1280,
        intermediate_size: int = 1280,
        head_dim: int = 80,
        num_attention_heads: int = 16,
        tolerance: float = 1e-6,
        rms_norm_eps: float = 1e-6,
        quantiles: List[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
        pad_val: float = 1123581321.0,
        attention_dropout: float = 0.0,
        use_positional_embedding: bool = False,
        initializer_range: float = 0.02,
        min_timescale: int = 1,
        max_timescale: int = 10_000,
        **kwargs,
    ):
        self.patch_length = patch_length
        self.context_length = context_length
        self.horizon_length = horizon_length
        self.quantiles = quantiles
        self.pad_val = pad_val
        self.freq_size = freq_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.head_dim = head_dim
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.tolerance = tolerance
        self.rms_norm_eps = rms_norm_eps
        self.attention_dropout = attention_dropout
        self.use_positional_embedding = use_positional_embedding
        self.initializer_range = initializer_range
        self.min_timescale = min_timescale
        self.max_timescale = max_timescale

        super().__init__(
            is_encoder_decoder=self.is_encoder_decoder,
            **kwargs,
        )


__all__ = ["TimesFmConfig"]
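The docstring above documents how the configuration maps onto the architecture. As a minimal, hedged sketch of wiring a config into a randomly initialized prediction model (the small values below are illustrative only and do not correspond to any released checkpoint):

```python
from transformers import TimesFmConfig, TimesFmModelForPrediction

# Shrunken, illustrative hyperparameters; the defaults match google/timesfm-2.0-500m-pytorch.
config = TimesFmConfig(num_hidden_layers=2, hidden_size=128, intermediate_size=128, head_dim=8)
model = TimesFmModelForPrediction(config)  # randomly initialized, no pretrained weights
print(sum(p.numel() for p in model.parameters()))
```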
275 src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py Normal file
@@ -0,0 +1,275 @@
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
|
||||
import numpy as np
|
||||
import timesfm
|
||||
import torch
|
||||
|
||||
from transformers import TimesFmConfig, TimesFmModelForPrediction
|
||||
|
||||
|
||||
"""
|
||||
Sample usage:
|
||||
|
||||
```
|
||||
python src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py \
|
||||
--output_dir /output/path
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def get_nested_attr(obj, key):
|
||||
"""Recursively retrieves an attribute from an object, handling list/tuple indexing if present."""
|
||||
parts = key.split(".")
|
||||
for part in parts:
|
||||
match = re.match(r"(.*)\[(\d+)\]", part) # Handle list indexing like `layers[0]`
|
||||
if match:
|
||||
attr_name, index = match.groups()
|
||||
obj = getattr(obj, attr_name)[int(index)] # Access list/tuple element
|
||||
else:
|
||||
obj = getattr(obj, part) # Regular attribute access
|
||||
return obj
|
||||
|
||||
|
||||
def write_model(model_path, safe_serialization=True, huggingface_repo_id="google/timesfm-2.0-500m-pytorch"):
|
||||
os.makedirs(model_path, exist_ok=True)
|
||||
tmp_model_path = os.path.join(model_path, "tmp")
|
||||
os.makedirs(tmp_model_path, exist_ok=True)
|
||||
|
||||
tfm = timesfm.TimesFm(
|
||||
hparams=timesfm.TimesFmHparams(
|
||||
backend="cuda" if torch.cuda.is_available() else "cpu",
|
||||
per_core_batch_size=32,
|
||||
horizon_len=128,
|
||||
input_patch_len=32,
|
||||
output_patch_len=128,
|
||||
num_layers=50,
|
||||
model_dims=1280,
|
||||
use_positional_embedding=False,
|
||||
),
|
||||
checkpoint=timesfm.TimesFmCheckpoint(huggingface_repo_id=huggingface_repo_id),
|
||||
)
|
||||
|
||||
timesfm_config = TimesFmConfig(
|
||||
patch_length=tfm.hparams.input_patch_len,
|
||||
context_length=tfm.hparams.context_len,
|
||||
horizon_length=tfm.hparams.horizon_len,
|
||||
num_hidden_layers=tfm.hparams.num_layers,
|
||||
hidden_size=tfm.hparams.model_dims,
|
||||
intermediate_size=tfm.hparams.model_dims,
|
||||
head_dim=tfm.hparams.model_dims // tfm.hparams.num_heads,
|
||||
num_attention_heads=tfm.hparams.num_heads,
|
||||
use_positional_embedding=tfm.hparams.use_positional_embedding,
|
||||
)
|
||||
timesfm_config.save_pretrained(tmp_model_path)
|
||||
timesfm_model = TimesFmModelForPrediction(timesfm_config)
|
||||
|
||||
# copy the weights from the original model to the new model
|
||||
original_model = tfm._model
|
||||
|
||||
# mapping of the layers from the original model to the transformer model
|
||||
MODEL_LAYER_MAPPING = {
|
||||
"input_ff_layer.hidden_layer[0].weight": "decoder.input_ff_layer.input_layer.weight",
|
||||
"input_ff_layer.hidden_layer[0].bias": "decoder.input_ff_layer.input_layer.bias",
|
||||
"input_ff_layer.output_layer.weight": "decoder.input_ff_layer.output_layer.weight",
|
||||
"input_ff_layer.output_layer.bias": "decoder.input_ff_layer.output_layer.bias",
|
||||
"input_ff_layer.residual_layer.weight": "decoder.input_ff_layer.residual_layer.weight",
|
||||
"input_ff_layer.residual_layer.bias": "decoder.input_ff_layer.residual_layer.bias",
|
||||
"freq_emb.weight": "decoder.freq_emb.weight",
|
||||
"horizon_ff_layer.hidden_layer[0].weight": "horizon_ff_layer.input_layer.weight",
|
||||
"horizon_ff_layer.hidden_layer[0].bias": "horizon_ff_layer.input_layer.bias",
|
||||
"horizon_ff_layer.output_layer.weight": "horizon_ff_layer.output_layer.weight",
|
||||
"horizon_ff_layer.output_layer.bias": "horizon_ff_layer.output_layer.bias",
|
||||
"horizon_ff_layer.residual_layer.weight": "horizon_ff_layer.residual_layer.weight",
|
||||
"horizon_ff_layer.residual_layer.bias": "horizon_ff_layer.residual_layer.bias",
|
||||
}
|
||||
|
||||
TRANSFORMER_LAYER_MAPPING = {
|
||||
"stacked_transformer.layers[{i}].self_attn.qkv_proj.weight": "decoder.layers[{i}].self_attn.qkv_proj.weight",
|
||||
"stacked_transformer.layers[{i}].self_attn.qkv_proj.bias": "decoder.layers[{i}].self_attn.qkv_proj.bias",
|
||||
"stacked_transformer.layers[{i}].self_attn.o_proj.weight": "decoder.layers[{i}].self_attn.o_proj.weight",
|
||||
"stacked_transformer.layers[{i}].self_attn.o_proj.bias": "decoder.layers[{i}].self_attn.o_proj.bias",
|
||||
"stacked_transformer.layers[{i}].self_attn.scaling": "decoder.layers[{i}].self_attn.scaling",
|
||||
"stacked_transformer.layers[{i}].mlp.gate_proj.weight": "decoder.layers[{i}].mlp.gate_proj.weight",
|
||||
"stacked_transformer.layers[{i}].mlp.gate_proj.bias": "decoder.layers[{i}].mlp.gate_proj.bias",
|
||||
"stacked_transformer.layers[{i}].mlp.down_proj.weight": "decoder.layers[{i}].mlp.down_proj.weight",
|
||||
"stacked_transformer.layers[{i}].mlp.down_proj.bias": "decoder.layers[{i}].mlp.down_proj.bias",
|
||||
"stacked_transformer.layers[{i}].mlp.layer_norm.weight": "decoder.layers[{i}].mlp.layer_norm.weight",
|
||||
"stacked_transformer.layers[{i}].mlp.layer_norm.bias": "decoder.layers[{i}].mlp.layer_norm.bias",
|
||||
"stacked_transformer.layers[{i}].input_layernorm.weight": "decoder.layers[{i}].input_layernorm.weight",
|
||||
}
|
||||
|
||||
for old_key, new_key in MODEL_LAYER_MAPPING.items():
|
||||
try:
|
||||
old_attr = get_nested_attr(original_model, old_key) # Get tensor from original model
|
||||
new_attr = get_nested_attr(timesfm_model, new_key) # Get corresponding attribute in new model
|
||||
new_attr.data.copy_(old_attr.data) # Copy data
|
||||
except AttributeError:
|
||||
print(f"Skipping {old_key} (not found in original model).")
|
||||
|
||||
num_layers = len(timesfm_model.decoder.layers)
|
||||
for i in range(num_layers):
|
||||
for old_template, new_template in TRANSFORMER_LAYER_MAPPING.items():
|
||||
old_key = old_template.format(i=i)
|
||||
new_key = new_template.format(i=i)
|
||||
|
||||
try:
|
||||
# Get tensor from original model
|
||||
old_attr = get_nested_attr(original_model, old_key)
|
||||
if "qkv_proj" in old_key:
|
||||
# Split the tensor into q, k, v projections
|
||||
q_proj, k_proj, v_proj = (
|
||||
old_attr[: tfm.hparams.model_dims, ...],
|
||||
old_attr[tfm.hparams.model_dims : tfm.hparams.model_dims * 2, ...],
|
||||
old_attr[tfm.hparams.model_dims * 2 :, ...],
|
||||
)
|
||||
# Get corresponding attribute in new model
|
||||
q_key = new_key.replace("qkv_proj", "q_proj")
|
||||
q_attr = get_nested_attr(timesfm_model, q_key)
|
||||
q_attr.data.copy_(q_proj.data) # Copy data
|
||||
k_key = new_key.replace("qkv_proj", "k_proj")
|
||||
k_attr = get_nested_attr(timesfm_model, k_key)
|
||||
k_attr.data.copy_(k_proj.data) # Copy data
|
||||
v_key = new_key.replace("qkv_proj", "v_proj")
|
||||
v_attr = get_nested_attr(timesfm_model, v_key)
|
||||
v_attr.data.copy_(v_proj.data) # Copy data
|
||||
else:
|
||||
# Get corresponding attribute in new model
|
||||
new_attr = get_nested_attr(timesfm_model, new_key)
|
||||
new_attr.data.copy_(old_attr.data) # Copy data
|
||||
except AttributeError:
|
||||
print(f"Skipping {old_key} (not found in original model).")
|
||||
|
||||
timesfm_model.save_pretrained(model_path, safe_serialization=safe_serialization)
|
||||
shutil.rmtree(tmp_model_path)
|
||||
|
||||
|
||||
def check_outputs(model_path, huggingface_repo_id):
|
||||
"""Compares outputs between original and converted models."""
|
||||
print("\nChecking model outputs...")
|
||||
|
||||
# Load original model
|
||||
tfm = timesfm.TimesFm(
|
||||
hparams=timesfm.TimesFmHparams(
|
||||
backend="cuda" if torch.cuda.is_available() else "cpu",
|
||||
per_core_batch_size=32,
|
||||
horizon_len=128,
|
||||
input_patch_len=32,
|
||||
output_patch_len=128,
|
||||
num_layers=50,
|
||||
model_dims=1280,
|
||||
use_positional_embedding=False,
|
||||
point_forecast_mode="mean",
|
||||
),
|
||||
checkpoint=timesfm.TimesFmCheckpoint(huggingface_repo_id=huggingface_repo_id),
|
||||
)
|
||||
|
||||
# Load converted model
|
||||
converted_model = TimesFmModelForPrediction.from_pretrained(
|
||||
model_path,
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation="sdpa",
|
||||
).to("cuda" if torch.cuda.is_available() else "cpu")
|
||||
converted_model.eval() # Set to evaluation mode
|
||||
|
||||
# Create test inputs
|
||||
forecast_input = [
|
||||
np.sin(np.linspace(0, 20, 100)),
|
||||
np.sin(np.linspace(0, 20, 200)),
|
||||
np.sin(np.linspace(0, 20, 400)),
|
||||
]
|
||||
frequency_input = [0, 1, 2]
|
||||
|
||||
# Get predictions from original model
|
||||
point_forecast_orig, quantile_forecast_orig = tfm.forecast(
|
||||
forecast_input,
|
||||
freq=frequency_input,
|
||||
)
|
||||
|
||||
# Convert inputs to sequence of tensors
|
||||
forecast_input_tensor = [
|
||||
torch.tensor(ts, dtype=torch.bfloat16).to("cuda" if torch.cuda.is_available() else "cpu")
|
||||
for ts in forecast_input
|
||||
]
|
||||
frequency_input_tensor = torch.tensor(frequency_input, dtype=torch.long).to(
|
||||
"cuda" if torch.cuda.is_available() else "cpu"
|
||||
)
|
||||
|
||||
# Get predictions from converted model
|
||||
with torch.no_grad():
|
||||
outputs = converted_model(past_values=forecast_input_tensor, freq=frequency_input_tensor, return_dict=True)
|
||||
point_forecast_conv = outputs.mean_predictions.float().cpu().numpy()
|
||||
quantile_forecast_conv = outputs.full_predictions.float().cpu().numpy()
|
||||
|
||||
# Compare outputs
|
||||
point_forecast_diff = np.abs(point_forecast_orig - point_forecast_conv)
|
||||
quantile_forecast_diff = np.abs(quantile_forecast_orig - quantile_forecast_conv)
|
||||
|
||||
max_point_diff = point_forecast_diff.max()
|
||||
mean_point_diff = point_forecast_diff.mean()
|
||||
max_quantile_diff = quantile_forecast_diff.max()
|
||||
mean_quantile_diff = quantile_forecast_diff.mean()
|
||||
|
||||
print("\nOutput comparison:")
|
||||
print(f"Point forecast - Max difference: {max_point_diff:.6f}")
|
||||
print(f"Point forecast - Mean difference: {mean_point_diff:.6f}")
|
||||
print(f"Quantile forecast - Max difference: {max_quantile_diff:.6f}")
|
||||
print(f"Quantile forecast - Mean difference: {mean_quantile_diff:.6f}")
|
||||
|
||||
# Define acceptable thresholds
|
||||
POINT_THRESHOLD = 1e-5
|
||||
QUANTILE_THRESHOLD = 1e-5
|
||||
|
||||
if max_point_diff > POINT_THRESHOLD or max_quantile_diff > QUANTILE_THRESHOLD:
|
||||
raise ValueError(
|
||||
f"Output mismatch detected!\n"
|
||||
f"Point forecast max diff: {max_point_diff} (threshold: {POINT_THRESHOLD})\n"
|
||||
f"Quantile forecast max diff: {max_quantile_diff} (threshold: {QUANTILE_THRESHOLD})"
|
||||
)
|
||||
|
||||
print("\n✓ All outputs match within acceptable tolerance!")
|
||||
|
||||
# Optional: Print shapes for verification
|
||||
print("\nOutput shapes:")
|
||||
print(f"Original point forecast: {point_forecast_orig.shape}")
|
||||
print(f"Converted point forecast: {point_forecast_conv.shape}")
|
||||
print(f"Original quantile forecast: {quantile_forecast_orig.shape}")
|
||||
print(f"Converted quantile forecast: {quantile_forecast_conv.shape}")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
required=True,
|
||||
help="Location to write HF model and tokenizer",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--safe_serialization", type=bool, default=True, help="Whether or not to save using `safetensors`."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--huggingface_repo_id",
|
||||
type=str,
|
||||
default="google/timesfm-2.0-500m-pytorch",
|
||||
help="The Hugging Face repository ID to use for the model.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# if the saved model file exists, skip the conversion
|
||||
if os.path.exists(
|
||||
os.path.join(args.output_dir, "model.safetensors" if args.safe_serialization else "pytorch_model.bin")
|
||||
):
|
||||
print(f"Model already exists in {args.output_dir}, skipping conversion.")
|
||||
else:
|
||||
write_model(
|
||||
model_path=args.output_dir,
|
||||
safe_serialization=args.safe_serialization,
|
||||
huggingface_repo_id=args.huggingface_repo_id,
|
||||
)
|
||||
check_outputs(args.output_dir, args.huggingface_repo_id)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
904 src/transformers/models/timesfm/modeling_timesfm.py Normal file
@@ -0,0 +1,904 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from src/transformers/models/timesfm/modular_timesfm.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_timesfm.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# coding=utf-8
|
||||
# Copyright 2025 Google LLC and HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Optional, Sequence, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...integrations import use_kernel_forward_from_hub
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import BaseModelOutput
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
add_code_sample_docstrings,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
can_return_tuple,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from .configuration_timesfm import TimesFmConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CHECKPOINT_FOR_DOC = "google/timesfm-2.0-500m-pytorch"
|
||||
_CONFIG_FOR_DOC = "TimesFmConfig"
|
||||
|
||||
|
||||
@dataclass
|
||||
class TimesFmOutput(BaseModelOutput):
|
||||
"""
|
||||
Args:
|
||||
loc (`torch.Tensor` of shape `(batch_size, )`):
|
||||
The mean of the time series inputs.
|
||||
scale (`torch.Tensor` of shape `(batch_size,)`):
|
||||
The scale of the time series inputs.
|
||||
"""
|
||||
|
||||
loc: Optional[torch.Tensor] = None
|
||||
scale: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class TimesFmOutputForPrediction(BaseModelOutput):
|
||||
"""
|
||||
Args:
|
||||
mean_predictions (`torch.Tensor` of shape `(batch_size, sequence_length)`):
|
||||
The mean predictions of the time series.
|
||||
full_predictions (`torch.Tensor` of shape `(batch_size, sequence_length)`):
|
||||
The full predictions of the time series including the mean and the quantiles.
|
||||
loss (`torch.Tensor` of shape `(1,)`, *optional*, returned when `future_values` is provided):
|
||||
The loss of the TimesFM model.
|
||||
"""
|
||||
|
||||
mean_predictions: Optional[torch.Tensor] = None
|
||||
full_predictions: Optional[torch.Tensor] = None
|
||||
loss: Optional[Union[torch.Tensor, float]] = None
|
||||
|
||||
|
||||
class TimesFmMLP(nn.Module):
|
||||
"""Pax MLP in pytorch."""
|
||||
|
||||
def __init__(self, config: TimesFmConfig):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
intermediate_size = config.intermediate_size
|
||||
|
||||
self.gate_proj = nn.Linear(hidden_size, intermediate_size)
|
||||
self.down_proj = nn.Linear(intermediate_size, hidden_size)
|
||||
self.layer_norm = nn.LayerNorm(normalized_shape=hidden_size, eps=1e-6)
|
||||
|
||||
def forward(self, x, paddings=None):
|
||||
gate_inp = self.layer_norm(x)
|
||||
gate = self.gate_proj(gate_inp)
|
||||
gate = F.relu(gate)
|
||||
outputs = self.down_proj(gate)
|
||||
if paddings is not None:
|
||||
outputs = outputs * (1.0 - paddings[:, :, None])
|
||||
return outputs + x
|
||||
|
||||
|
||||
class TimesFmResidualBlock(nn.Module):
|
||||
"""TimesFM residual block."""
|
||||
|
||||
def __init__(self, input_dims, hidden_dims, output_dims):
|
||||
super().__init__()
|
||||
self.input_dims = input_dims
|
||||
self.hidden_dims = hidden_dims
|
||||
self.output_dims = output_dims
|
||||
|
||||
self.input_layer = nn.Linear(input_dims, hidden_dims)
|
||||
self.activation = nn.SiLU()
|
||||
self.output_layer = nn.Linear(hidden_dims, output_dims)
|
||||
self.residual_layer = nn.Linear(input_dims, output_dims)
|
||||
|
||||
def forward(self, x):
|
||||
hidden = self.input_layer(x)
|
||||
hidden = self.activation(hidden)
|
||||
output = self.output_layer(hidden)
|
||||
residual = self.residual_layer(x)
|
||||
return output + residual
|
||||
|
||||
|
||||
@use_kernel_forward_from_hub("RMSNorm")
|
||||
class TimesFmRMSNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
"""
|
||||
TimesFmRMSNorm is equivalent to T5LayerNorm
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
|
||||
def extra_repr(self):
|
||||
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
||||
|
||||
|
||||
class TimesFmPositionalEmbedding(nn.Module):
|
||||
"""Generates position embedding for a given 1-d sequence."""
|
||||
|
||||
def __init__(self, config: TimesFmConfig):
|
||||
super().__init__()
|
||||
min_timescale = config.min_timescale
|
||||
max_timescale = config.max_timescale
|
||||
self.embedding_dims = config.hidden_size
|
||||
|
||||
num_timescales = self.embedding_dims // 2
|
||||
log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / max(num_timescales - 1, 1)
|
||||
self.register_buffer(
|
||||
"inv_timescales",
|
||||
min_timescale * torch.exp(torch.arange(num_timescales, dtype=torch.float32) * -log_timescale_increment),
|
||||
)
|
||||
|
||||
def forward(self, seq_length=None, position=None):
|
||||
"""Generates a Tensor of sinusoids with different frequencies.
|
||||
|
||||
Args:
|
||||
seq_length: an optional Python int defining the output sequence length.
|
||||
Not required if the `position` argument is specified.
|
||||
position: [B, seq_length], optional position for each token in the
|
||||
sequence, only required when the sequence is packed.
|
||||
|
||||
Returns:
|
||||
[B, seqlen, D] if `position` is specified, else [1, seqlen, D]
|
||||
"""
|
||||
if position is None and seq_length is None:
|
||||
raise ValueError("Either position or seq_length must be provided")
|
||||
|
||||
if position is None:
|
||||
# [1, seqlen]
|
||||
position = torch.arange(seq_length, dtype=torch.float32, device=self.inv_timescales.device).unsqueeze(0)
|
||||
elif position.ndim != 2:
|
||||
raise ValueError(f"position must be 2-dimensional, got shape {position.shape}")
|
||||
|
||||
scaled_time = position.view(*position.shape, 1) * self.inv_timescales.view(1, 1, -1)
|
||||
signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)
|
||||
|
||||
# Padding to ensure correct embedding dimension
|
||||
signal = F.pad(signal, (0, 0, 0, self.embedding_dims % 2))
|
||||
return signal
|
||||
|
||||
|
||||
def simple_eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query_states: torch.Tensor,
|
||||
key_states: torch.Tensor,
|
||||
value_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
**kwargs,
|
||||
):
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class TimesFmAttention(nn.Module):
|
||||
"""Implements the attention used in TimesFM. One key difference is that there is _per_dim_scaling of the query."""
|
||||
|
||||
def __init__(self, config: TimesFmConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.is_causal = True
|
||||
self.attention_dropout = config.attention_dropout
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.hidden_size = config.hidden_size
|
||||
self.head_dim = config.head_dim
|
||||
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_heads * self.head_dim
|
||||
self.scaling = nn.Parameter(torch.empty((self.head_dim,)))
|
||||
|
||||
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim)
|
||||
self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim)
|
||||
self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim)
|
||||
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size)
|
||||
|
||||
def _scale_query(self, query: torch.Tensor) -> torch.Tensor:
|
||||
scale = F.softplus(self.scaling).mul(1.442695041 / math.sqrt(self.head_dim))
|
||||
return query * scale[None, None, None, :]
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, self.head_dim)
|
||||
|
||||
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
query_states = self._scale_query(query_states)
|
||||
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
attention_interface: Callable = simple_eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
|
||||
logger.warning_once(
|
||||
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
|
||||
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
else:
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
scaling=1.0,
|
||||
**kwargs,
|
||||
)
|
||||
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class TimesFmDecoderLayer(nn.Module):
|
||||
"""Transformer layer."""
|
||||
|
||||
def __init__(self, config: TimesFmConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
|
||||
self.self_attn = TimesFmAttention(config, layer_idx=layer_idx)
|
||||
self.mlp = TimesFmMLP(config)
|
||||
self.input_layernorm = TimesFmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
paddings: torch.Tensor,
|
||||
output_attentions: bool = False,
|
||||
) -> tuple[Optional[torch.Tensor], torch.Tensor]:
|
||||
# Self Attention
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
hidden_states, scores = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# MLP
|
||||
hidden_states = self.mlp(hidden_states, paddings=paddings)
|
||||
|
||||
return scores, hidden_states
|
||||
|
||||
|
||||
TIMESFM_START_DOCSTRING = r"""
|
||||
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
|
||||
etc.)
|
||||
|
||||
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
||||
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
||||
and behavior.
|
||||
|
||||
Parameters:
|
||||
config ([`TimesFmConfig`]):
|
||||
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
||||
load the weights associated with the model, only the configuration. Check out the
|
||||
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare TimesFM Model outputting raw hidden-states without any specific head on top.",
|
||||
TIMESFM_START_DOCSTRING,
|
||||
)
|
||||
class TimesFmPreTrainedModel(PreTrainedModel):
|
||||
"""handles the loading for all models."""
|
||||
|
||||
config_class = TimesFmConfig
|
||||
base_model_prefix = "timesfm"
|
||||
_no_split_modules = ["TimesFmDecoderLayer"]
|
||||
main_input_name = "past_values"
|
||||
_supports_sdpa = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
if isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0, std=self.config.initializer_range)
|
||||
|
||||
elif isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
nn.init.zeros_(module.bias)
|
||||
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
nn.init.ones_(module.weight)
|
||||
nn.init.zeros_(module.bias)
|
||||
|
||||
elif isinstance(module, TimesFmRMSNorm):
|
||||
nn.init.zeros_(module.weight)
|
||||
|
||||
elif isinstance(module, TimesFmAttention):
|
||||
# Initialize scaling parameter
|
||||
nn.init.ones_(module.scaling)
|
||||
|
||||
|
||||
TIMESFM_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
past_values: list of time series forecast contexts. Each context time series
|
||||
can be a torch Tensor of potentially different context lengths.
|
||||
freq: frequency of each context time series in the inputs. 0 for high frequency
|
||||
(default), 1 for medium, and 2 for low.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||||
tensors for more detail.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||
more detail.
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare TimesFM Model outputting raw hidden-states without any specific head on top.",
|
||||
TIMESFM_START_DOCSTRING,
|
||||
)
|
||||
class TimesFmModel(TimesFmPreTrainedModel):
|
||||
"""Patched time-series decoder without any specific output layer."""
|
||||
|
||||
def __init__(self, config: TimesFmConfig):
|
||||
super().__init__(config)
|
||||
|
||||
self.config = config
|
||||
self.input_ff_layer = TimesFmResidualBlock(
|
||||
input_dims=2 * config.patch_length,
|
||||
output_dims=config.hidden_size,
|
||||
hidden_dims=config.intermediate_size,
|
||||
)
|
||||
self.freq_emb = nn.Embedding(num_embeddings=config.freq_size, embedding_dim=config.hidden_size)
|
||||
self.layers = nn.ModuleList(
|
||||
[TimesFmDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||
)
|
||||
if self.config.use_positional_embedding:
|
||||
self.position_emb = TimesFmPositionalEmbedding(config=config)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def _forward_transform(
|
||||
self, inputs: torch.Tensor, patched_pads: torch.Tensor
|
||||
) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""Input is of shape [B, N, P]."""
|
||||
mu, sigma = self._timesfm_masked_mean_std(inputs, patched_pads)
|
||||
sigma = torch.where(
|
||||
sigma < self.config.tolerance,
|
||||
torch.tensor(1.0, dtype=sigma.dtype, device=sigma.device),
|
||||
sigma,
|
||||
)
|
||||
|
||||
# Normalize each patch
|
||||
outputs = (inputs - mu[:, None, None]) / sigma[:, None, None]
|
||||
outputs = torch.where(
|
||||
torch.abs(inputs - self.config.pad_val) < self.config.tolerance,
|
||||
torch.tensor(self.config.pad_val, dtype=outputs.dtype, device=outputs.device),
|
||||
outputs,
|
||||
)
|
||||
return outputs, (mu, sigma)
|
||||
|
||||
@can_return_tuple
|
||||
@add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
past_values: torch.Tensor,
|
||||
past_values_padding: torch.LongTensor,
|
||||
freq: torch.Tensor,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
) -> TimesFmOutput:
|
||||
"""
|
||||
past_values_padding (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
The padding indicator of the time series.
|
||||
"""
|
||||
# Reshape into patches (using view for efficiency)
|
||||
bsize = past_values.shape[0]
|
||||
patched_inputs = past_values.view(bsize, -1, self.config.patch_length)
|
||||
patched_pads = past_values_padding.view(bsize, -1, self.config.patch_length)
|
||||
|
||||
patched_inputs = torch.where(
|
||||
torch.abs(patched_pads - 1.0) < self.config.tolerance,
|
||||
torch.tensor(0.0, dtype=patched_inputs.dtype, device=patched_inputs.device),
|
||||
patched_inputs,
|
||||
)
|
||||
patched_pads = torch.where(
|
||||
torch.abs(patched_inputs - self.config.pad_val) < self.config.tolerance,
|
||||
torch.tensor(1.0, dtype=patched_pads.dtype, device=patched_pads.device),
|
||||
patched_pads,
|
||||
)
|
||||
patched_inputs, stats = self._forward_transform(patched_inputs, patched_pads)
|
||||
|
||||
# B x N x D
|
||||
patched_inputs = patched_inputs * (1.0 - patched_pads)
|
||||
concat_inputs = torch.cat([patched_inputs, patched_pads], dim=-1)
|
||||
model_input = self.input_ff_layer(concat_inputs)
|
||||
|
||||
# A patch should not be padded even if there is at least one zero.
|
||||
patched_padding = torch.min(patched_pads, dim=-1)[0] # Get the values from the min result
|
||||
if self.config.use_positional_embedding:
|
||||
pos_emb = self.position_emb(model_input.shape[1])
|
||||
pos_emb = torch.concat([pos_emb] * model_input.shape[0], dim=0)
|
||||
pos_emb = self._timesfm_shift_padded_seq(patched_padding, pos_emb)
|
||||
model_input += pos_emb
|
||||
|
||||
f_emb = self.freq_emb(freq) # B x 1 x D
|
||||
model_input += f_emb
|
||||
|
||||
# Convert paddings to attention mask and combine with causal mask
|
||||
hidden_states = model_input
|
||||
attention_mask = self._prepare_4d_attention_mask(
|
||||
attention_mask=patched_padding,
|
||||
sequence_length=hidden_states.shape[1],
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device,
|
||||
is_causal=True,
|
||||
)
|
||||
|
||||
all_attentions = []
|
||||
all_hidden_states = []
|
||||
|
||||
for layer in self.layers[: self.config.num_hidden_layers]:
|
||||
scores, hidden_states = layer(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
paddings=patched_padding,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
if output_attentions:
|
||||
all_attentions.append(scores)
|
||||
if output_hidden_states:
|
||||
all_hidden_states.append(hidden_states)
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states = [model_input] + all_hidden_states
|
||||
else:
|
||||
all_hidden_states = None
|
||||
|
||||
return TimesFmOutput(
|
||||
last_hidden_state=hidden_states,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_attentions if output_attentions else None,
|
||||
loc=stats[0],
|
||||
scale=stats[1],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _prepare_4d_attention_mask(
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
sequence_length: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
is_causal: bool = True,
|
||||
) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
Creates 4D attention mask and combines causal and padding masks if needed.
|
||||
|
||||
Args:
|
||||
attention_mask: Optional tensor of shape (batch_size, seq_length) containing padding mask
|
||||
sequence_length: Length of the sequence
|
||||
dtype: Data type of the mask
|
||||
device: Device of the mask
|
||||
is_causal: Whether to apply causal masking
|
||||
|
||||
Returns:
|
||||
4D attention mask of shape (batch_size, 1, seq_length, seq_length)
|
||||
"""
|
||||
# Get minimum value for the dtype
|
||||
min_value = torch.finfo(dtype).min if dtype.is_floating_point else torch.iinfo(dtype).min
|
||||
|
||||
# Handle padding mask
|
||||
if attention_mask is not None:
|
||||
# Convert 2D padding mask to 4D attention mask
|
||||
attention_mask = attention_mask.view(attention_mask.shape[0], 1, 1, -1)
|
||||
attention_mask = attention_mask * min_value
|
||||
|
||||
# Create causal mask if needed
|
||||
if is_causal:
|
||||
causal_mask = torch.triu(
|
||||
torch.ones((sequence_length, sequence_length), dtype=dtype, device=device) * min_value,
|
||||
diagonal=1,
|
||||
)
|
||||
causal_mask = causal_mask.view(1, 1, sequence_length, sequence_length)
|
||||
|
||||
# Combine with padding mask if it exists
|
||||
if attention_mask is not None:
|
||||
attention_mask = torch.minimum(attention_mask, causal_mask)
|
||||
else:
|
||||
attention_mask = causal_mask
|
||||
|
||||
return attention_mask
|
||||
|
||||
@staticmethod
|
||||
def _timesfm_masked_mean_std(inputs: torch.Tensor, padding: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Calculates mean and standard deviation of `inputs` across axis 1.
|
||||
|
||||
It excludes values where `padding` is 1.
|
||||
|
||||
Args:
|
||||
inputs: A PyTorch tensor of shape [b, n, p].
|
||||
padding: A PyTorch tensor of shape [b, n, p] with values 0 or 1.
|
||||
|
||||
Returns:
|
||||
A tuple containing the mean and standard deviation.
|
||||
We return the statistics of the first patch with more than three non-padded values.
|
||||
"""
|
||||
|
||||
# Selecting the first patch with more than 3 unpadded values.
|
||||
def _get_patch_index(arr: torch.Tensor):
|
||||
indices = torch.argmax((arr >= 3).to(torch.int32), dim=1)
|
||||
row_sum = (arr >= 3).to(torch.int32).sum(dim=1)
|
||||
return torch.where(row_sum == 0, arr.shape[1] - 1, indices)
|
||||
|
||||
pad_sum = torch.sum(1 - padding, dim=2)
|
||||
patch_indices = _get_patch_index(pad_sum)
|
||||
bidxs = torch.arange(inputs.shape[0])
|
||||
|
||||
arr = inputs[bidxs, patch_indices, :]
|
||||
pad = padding[bidxs, patch_indices, :]
|
||||
|
||||
# Create a mask where padding is 0
|
||||
mask = 1 - pad
|
||||
|
||||
# Calculate the number of valid elements
|
||||
num_valid_elements = torch.sum(mask, dim=1)
|
||||
num_valid_elements = torch.where(
|
||||
num_valid_elements == 0,
|
||||
torch.tensor(1, dtype=num_valid_elements.dtype, device=num_valid_elements.device),
|
||||
num_valid_elements,
|
||||
)
|
||||
|
||||
# Calculate the masked sum and squared sum
|
||||
masked_sum = torch.sum(arr * mask, dim=1)
|
||||
masked_squared_sum = torch.sum((arr * mask) ** 2, dim=1)
|
||||
|
||||
# Calculate the masked mean and standard deviation
|
||||
masked_mean = masked_sum / num_valid_elements
|
||||
masked_var = masked_squared_sum / num_valid_elements - masked_mean**2
|
||||
masked_var = torch.where(
|
||||
masked_var < 0.0,
|
||||
torch.tensor(0.0, dtype=masked_var.dtype, device=masked_var.device),
|
||||
masked_var,
|
||||
)
|
||||
masked_std = torch.sqrt(masked_var)
|
||||
|
||||
return masked_mean, masked_std
|
||||
|
||||
@staticmethod
|
||||
def _timesfm_shift_padded_seq(mask: torch.Tensor, seq: torch.Tensor) -> torch.Tensor:
|
||||
"""Shifts rows of seq based on the first 0 in each row of the mask.
|
||||
|
||||
Args:
|
||||
mask: mask tensor of shape [B, N]
|
||||
seq: seq tensor of shape [B, N, P]
|
||||
|
||||
Returns:
|
||||
The shifted sequence.
|
||||
"""
|
||||
batch_size, num_seq, feature_dim = seq.shape
|
||||
|
||||
new_mask: torch.BoolTensor = mask == 0
|
||||
|
||||
# Use argmax to find the first True value in each row
|
||||
indices = new_mask.to(torch.int32).argmax(dim=1)
|
||||
|
||||
# Handle rows with all zeros
|
||||
indices[~new_mask.any(dim=1)] = -1
|
||||
|
||||
# Create index ranges for each sequence in the batch
|
||||
idx_range = torch.arange(num_seq, device=seq.device).view(1, -1, 1).expand(batch_size, -1, feature_dim)
|
||||
|
||||
# Calculate shifted indices for each element in each sequence
|
||||
shifted_idx = (idx_range - indices[:, None, None]) % num_seq
|
||||
|
||||
# Gather values from seq using shifted indices
|
||||
shifted_seq = seq.gather(1, shifted_idx)
|
||||
|
||||
return shifted_seq
|
||||
|
||||
|
||||
class TimesFmModelForPrediction(TimesFmPreTrainedModel):
|
||||
"""TimesFM model for quantile and mean prediction."""
|
||||
|
||||
def __init__(self, config: TimesFmConfig):
|
||||
super().__init__(config)
|
||||
|
||||
self.config = config
|
||||
self.context_len = config.context_length
|
||||
self.horizon_len = config.horizon_length
|
||||
|
||||
self.decoder = TimesFmModel(config)
|
||||
|
||||
# quantile and mean output
|
||||
self.horizon_ff_layer = TimesFmResidualBlock(
|
||||
input_dims=config.hidden_size,
|
||||
output_dims=config.horizon_length * (1 + len(config.quantiles)),
|
||||
hidden_dims=config.intermediate_size,
|
||||
)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def _preprocess(
|
||||
self, inputs: Sequence[torch.Tensor], freq: Sequence[int]
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Formats and pads raw inputs to feed into the model.
|
||||
|
||||
This function pads (or truncates) each time series so that it matches the
model's required context length.
|
||||
|
||||
Args:
|
||||
inputs: A list of 1d Tensors. Each Tensor is the context time series of
|
||||
a single forecast task.
|
||||
freq: list of frequencies
|
||||
|
||||
Returns:
|
||||
A tuple of:
|
||||
- the padded input time series to meet the model required context.
|
||||
- the padding indicator.
|
||||
- the per-series frequency indicator, as an int32 tensor of shape (batch_size, 1).
|
||||
"""
|
||||
input_ts, input_padding, inp_freq = [], [], []
|
||||
|
||||
for i, ts in enumerate(inputs):
|
||||
input_len = ts.shape[0]
|
||||
padding = torch.zeros(input_len + self.horizon_len, dtype=ts.dtype, device=ts.device)
|
||||
if input_len < self.context_len:
|
||||
num_front_pad = self.context_len - input_len
|
||||
ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0)
|
||||
padding = torch.cat([torch.ones(num_front_pad, dtype=ts.dtype, device=padding.device), padding], dim=0)
|
||||
elif input_len > self.context_len:
|
||||
ts = ts[-self.context_len :]
|
||||
padding = padding[-(self.context_len + self.horizon_len) :]
|
||||
|
||||
input_ts.append(ts)
|
||||
input_padding.append(padding)
|
||||
inp_freq.append(freq[i])
|
||||
|
||||
return (
|
||||
torch.stack(input_ts, dim=0),
|
||||
torch.stack(input_padding, dim=0),
|
||||
torch.tensor(inp_freq, dtype=torch.int32).reshape(-1, 1),
|
||||
)
|
||||
|
||||
def _postprocess_output(
|
||||
self, model_output: torch.Tensor, stats: tuple[torch.Tensor, torch.Tensor]
|
||||
) -> torch.Tensor:
|
||||
"""Postprocess output of stacked transformer."""
|
||||
|
||||
# B x N x (H.Q)
|
||||
output_ts = self.horizon_ff_layer(model_output)
|
||||
|
||||
# Reshape using view
|
||||
b, n, _ = output_ts.shape
|
||||
output_ts = output_ts.view(b, n, self.config.horizon_length, len(self.config.quantiles) + 1)
|
||||
|
||||
mu, sigma = stats
|
||||
return output_ts * sigma[:, None, None, None] + mu[:, None, None, None]
|
||||
|
||||
def _quantile_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
|
||||
losses = []
|
||||
for i, q in enumerate(self.config.quantiles):
|
||||
errors = targets - predictions[..., i]
|
||||
loss = torch.max((q - 1) * errors, q * errors)
|
||||
losses.append(loss.mean())
|
||||
return torch.stack(losses).mean()
|
||||
|
||||
@can_return_tuple
|
||||
@add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=TimesFmOutputForPrediction, config_class=_CONFIG_FOR_DOC)
|
||||
@add_code_sample_docstrings(
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=TimesFmOutputForPrediction,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
past_values: Sequence[torch.Tensor],
|
||||
freq: Optional[Sequence[Union[torch.Tensor, int]]] = None,
|
||||
window_size: Optional[int] = None,
|
||||
future_values: Optional[torch.Tensor] = None,
|
||||
forecast_context_len: Optional[int] = None,
|
||||
return_forecast_on_context: bool = False,
|
||||
truncate_negative: bool = False,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
) -> TimesFmOutputForPrediction:
|
||||
r"""
|
||||
window_size (`int`, *optional*):
|
||||
Window size of trend + residual decomposition. If None then we do not do decomposition.
|
||||
future_values (`torch.Tensor`, *optional*):
|
||||
Optional future time series values to be used for loss computation.
|
||||
forecast_context_len (`int`, *optional*):
|
||||
Optional max context length.
|
||||
return_forecast_on_context (`bool`, *optional*):
|
||||
True to return the forecast on the context when available, i.e. after the first input patch.
|
||||
truncate_negative (`bool`, *optional*):
|
||||
Truncate the forecast to non-negative values if all context values are non-negative
(i.e. the minimum over the contexts is >= 0); otherwise do nothing.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether to output the attentions.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether to output the hidden states.
|
||||
|
||||
Returns:
|
||||
A TimesFmOutputForPrediction object or a tuple containing:
|
||||
- the mean forecast of size (# past_values, # forecast horizon),
|
||||
- the full forecast (mean + quantiles) of size
|
||||
(# past_values, # forecast horizon, 1 + # quantiles).
|
||||
- loss: the mean squared error loss + quantile loss if `future_values` is provided.
|
||||
"""
|
||||
if forecast_context_len is None:
|
||||
fcontext_len = self.context_len
|
||||
else:
|
||||
fcontext_len = forecast_context_len
|
||||
|
||||
# Get device from first input tensor
|
||||
device = past_values[0].device
|
||||
|
||||
# Truncate inputs to forecast_context_len
|
||||
inputs = [ts[-fcontext_len:] for ts in past_values]
|
||||
inp_min = torch.min(torch.stack([torch.min(ts) for ts in inputs]))
|
||||
|
||||
if window_size is not None:
|
||||
new_inputs = []
|
||||
new_freqs = []
|
||||
for i, ts in enumerate(inputs):
|
||||
new_inputs.extend(self._timesfm_moving_average(ts, window_size))
|
||||
if freq is not None:
|
||||
new_freqs.extend([freq[i]] * 2)
|
||||
inputs = new_inputs
|
||||
if freq is not None:
|
||||
freq = new_freqs
|
||||
|
||||
if freq is None:
|
||||
logger.info("No frequency provided via `freq`. Default to high (0).")
|
||||
freq = [0] * len(inputs)
|
||||
|
||||
if output_attentions is None:
|
||||
output_attentions = self.config.output_attentions
|
||||
if output_hidden_states is None:
|
||||
output_hidden_states = self.config.output_hidden_states
|
||||
|
||||
input_ts, input_padding, inp_freq = self._preprocess(inputs, freq)
|
||||
# Move tensors to the same device as input
|
||||
input_ts = input_ts.to(device)
|
||||
input_padding = input_padding.to(device)
|
||||
inp_freq = inp_freq.to(device)
|
||||
|
||||
final_out = input_ts
|
||||
context_len = final_out.shape[1]
|
||||
full_outputs = []
|
||||
|
||||
if input_padding.shape[1] != final_out.shape[1] + self.horizon_len:
|
||||
raise ValueError(
|
||||
"Length of paddings must match length of input + horizon_len:"
|
||||
f" {input_padding.shape[1]} != {final_out.shape[1]} + {self.horizon_len}"
|
||||
)
|
||||
output_patch_len = self.config.horizon_length
|
||||
|
||||
num_decode_patches = (self.horizon_len + output_patch_len - 1) // output_patch_len
|
||||
for step_index in range(num_decode_patches):
|
||||
current_padding = input_padding[:, 0 : final_out.shape[1]]
|
||||
input_ts = final_out[:, -fcontext_len:]
|
||||
input_padding = current_padding[:, -fcontext_len:]
|
||||
decoder_output = self.decoder(
|
||||
past_values=input_ts,
|
||||
past_values_padding=input_padding,
|
||||
freq=inp_freq,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
fprop_outputs = self._postprocess_output(
|
||||
decoder_output.last_hidden_state,
|
||||
(decoder_output.loc, decoder_output.scale),
|
||||
)
|
||||
|
||||
if return_forecast_on_context and step_index == 0:
|
||||
# For the first decoding step, collect the model forecast on the
# context, except for the unavailable forecast of the first input patch.
|
||||
new_full_ts = fprop_outputs[:, :-1, : self.config.patch_length, :]
|
||||
# We have to use reshape and not view for non-contiguous memory
|
||||
new_full_ts = new_full_ts.reshape(new_full_ts.size(0), -1, new_full_ts.size(3))
|
||||
|
||||
full_outputs.append(new_full_ts)
|
||||
|
||||
# (full batch, last patch, output_patch_len, index of mean forecast = 0)
|
||||
new_ts = fprop_outputs[:, -1, :output_patch_len, 0]
|
||||
new_full_ts = fprop_outputs[:, -1, :output_patch_len, :]
|
||||
# (full batch, last patch, output_patch_len, all output indices)
|
||||
full_outputs.append(new_full_ts)
|
||||
final_out = torch.concatenate([final_out, new_ts], axis=-1)
|
||||
|
||||
if return_forecast_on_context:
|
||||
# `full_outputs` indexing starts right after the first input patch.
|
||||
full_outputs = torch.concatenate(full_outputs, axis=1)[
|
||||
:, : (context_len - self.config.patch_length + self.horizon_len), :
|
||||
]
|
||||
else:
|
||||
# `full_outputs` indexing starts at the forecast horizon.
|
||||
full_outputs = torch.concatenate(full_outputs, axis=1)[:, 0 : self.horizon_len, :]
|
||||
|
||||
mean_outputs = full_outputs[:, :, 0]
|
||||
if window_size is not None:
|
||||
mean_outputs = mean_outputs[0::2, ...] + mean_outputs[1::2, ...]
|
||||
full_outputs = full_outputs[0::2, ...] + full_outputs[1::2, ...]
|
||||
if inp_min >= 0 and truncate_negative:
|
||||
# `torch.maximum` requires a tensor operand, so clamp at zero instead
mean_outputs = torch.clamp(mean_outputs, min=0.0)
full_outputs = torch.clamp(full_outputs, min=0.0)
|
||||
|
||||
loss = None
|
||||
if future_values is not None:
|
||||
mse_loss = F.mse_loss(mean_outputs, future_values)
|
||||
quantile_loss = self._quantile_loss(full_outputs[:, :, 1:], future_values)
|
||||
loss = mse_loss + quantile_loss
|
||||
|
||||
return TimesFmOutputForPrediction(
|
||||
last_hidden_state=decoder_output.last_hidden_state,
|
||||
attentions=decoder_output.attentions if output_attentions else None,
|
||||
hidden_states=decoder_output.hidden_states if output_hidden_states else None,
|
||||
mean_predictions=mean_outputs,
|
||||
full_predictions=full_outputs,
|
||||
loss=loss,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _timesfm_moving_average(arr: torch.Tensor, window_size: int) -> list[torch.Tensor]:
|
||||
"""Calculates the moving average using PyTorch's convolution function."""
|
||||
# Pad with zeros to handle initial window positions
|
||||
arr_padded = F.pad(arr, (window_size - 1, 0), "constant", 0)
|
||||
# Create a convolution kernel
|
||||
kernel = torch.ones(window_size, dtype=arr.dtype, device=arr.device) / window_size
|
||||
# Apply convolution to calculate the moving average
|
||||
smoothed_arr = F.conv1d(arr_padded.view(1, 1, -1), kernel.view(1, 1, -1)).squeeze()
|
||||
return [smoothed_arr, arr - smoothed_arr]
|
||||
|
||||
|
||||
__all__ = ["TimesFmModelForPrediction", "TimesFmPreTrainedModel", "TimesFmModel"]
|
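For reviewers skimming the diff, here is a minimal usage sketch of the prediction head defined above. It mirrors the integration test further down in this PR; the checkpoint name comes from `_CHECKPOINT_FOR_DOC`, and the printed shapes follow the output docstrings (`horizon_length` and `quantiles` are read from the loaded config), so treat it as an illustration rather than official documentation.

import numpy as np
import torch
from transformers import TimesFmModelForPrediction

model = TimesFmModelForPrediction.from_pretrained("google/timesfm-2.0-500m-pytorch")
model.eval()

# Contexts may have different lengths; each one is a 1-d tensor.
past_values = [
    torch.tensor(np.sin(np.linspace(0, 20, 100)), dtype=torch.float32),
    torch.tensor(np.sin(np.linspace(0, 20, 200)), dtype=torch.float32),
]
with torch.no_grad():
    outputs = model(past_values=past_values, freq=[0, 1])

print(outputs.mean_predictions.shape)  # (2, horizon_length)
print(outputs.full_predictions.shape)  # (2, horizon_length, 1 + len(config.quantiles))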
src/transformers/models/timesfm/modular_timesfm.py (new file, 860 lines)
@@ -0,0 +1,860 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 Google LLC and HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyTorch TimesFM model."""
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Optional, Sequence, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import BaseModelOutput
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
add_code_sample_docstrings,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
can_return_tuple,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from ..llama.modeling_llama import LlamaRMSNorm
|
||||
from ..phi4_multimodal.modeling_phi4_multimodal import simple_eager_attention_forward
|
||||
from .configuration_timesfm import TimesFmConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CHECKPOINT_FOR_DOC = "google/timesfm-2.0-500m-pytorch"
|
||||
_CONFIG_FOR_DOC = "TimesFmConfig"
|
||||
|
||||
|
||||
@dataclass
|
||||
class TimesFmOutput(BaseModelOutput):
|
||||
"""
|
||||
Args:
|
||||
loc (`torch.Tensor` of shape `(batch_size, )`):
|
||||
The mean of the time series inputs.
|
||||
scale (`torch.Tensor` of shape `(batch_size,)`):
|
||||
The scale of the time series inputs.
|
||||
"""
|
||||
|
||||
loc: Optional[torch.Tensor] = None
|
||||
scale: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class TimesFmOutputForPrediction(BaseModelOutput):
|
||||
"""
|
||||
Args:
|
||||
mean_predictions (`torch.Tensor` of shape `(batch_size, sequence_length)`):
|
||||
The mean predictions of the time series.
|
||||
full_predictions (`torch.Tensor` of shape `(batch_size, sequence_length)`):
|
||||
The full predictions of the time series including the mean and the quantiles.
|
||||
loss (`torch.Tensor` of shape `(1,)`, *optional*, returned when `future_values` is provided):
|
||||
The loss of the TimesFM model.
|
||||
"""
|
||||
|
||||
mean_predictions: Optional[torch.Tensor] = None
|
||||
full_predictions: Optional[torch.Tensor] = None
|
||||
loss: Optional[Union[torch.Tensor, float]] = None
|
||||
|
||||
|
||||
class TimesFmMLP(nn.Module):
|
||||
"""Pax MLP in pytorch."""
|
||||
|
||||
def __init__(self, config: TimesFmConfig):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
intermediate_size = config.intermediate_size
|
||||
|
||||
self.gate_proj = nn.Linear(hidden_size, intermediate_size)
|
||||
self.down_proj = nn.Linear(intermediate_size, hidden_size)
|
||||
self.layer_norm = nn.LayerNorm(normalized_shape=hidden_size, eps=1e-6)
|
||||
|
||||
def forward(self, x, paddings=None):
|
||||
gate_inp = self.layer_norm(x)
|
||||
gate = self.gate_proj(gate_inp)
|
||||
gate = F.relu(gate)
|
||||
outputs = self.down_proj(gate)
|
||||
if paddings is not None:
|
||||
outputs = outputs * (1.0 - paddings[:, :, None])
|
||||
return outputs + x
|
||||
|
||||
|
||||
class TimesFmResidualBlock(nn.Module):
|
||||
"""TimesFM residual block."""
|
||||
|
||||
def __init__(self, input_dims, hidden_dims, output_dims):
|
||||
super().__init__()
|
||||
self.input_dims = input_dims
|
||||
self.hidden_dims = hidden_dims
|
||||
self.output_dims = output_dims
|
||||
|
||||
self.input_layer = nn.Linear(input_dims, hidden_dims)
|
||||
self.activation = nn.SiLU()
|
||||
self.output_layer = nn.Linear(hidden_dims, output_dims)
|
||||
self.residual_layer = nn.Linear(input_dims, output_dims)
|
||||
|
||||
def forward(self, x):
|
||||
hidden = self.input_layer(x)
|
||||
hidden = self.activation(hidden)
|
||||
output = self.output_layer(hidden)
|
||||
residual = self.residual_layer(x)
|
||||
return output + residual
|
||||
|
||||
|
||||
class TimesFmRMSNorm(LlamaRMSNorm):
|
||||
pass
|
||||
|
||||
|
||||
class TimesFmPositionalEmbedding(nn.Module):
|
||||
"""Generates position embedding for a given 1-d sequence."""
|
||||
|
||||
def __init__(self, config: TimesFmConfig):
|
||||
super().__init__()
|
||||
min_timescale = config.min_timescale
|
||||
max_timescale = config.max_timescale
|
||||
self.embedding_dims = config.hidden_size
|
||||
|
||||
num_timescales = self.embedding_dims // 2
|
||||
log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / max(num_timescales - 1, 1)
|
||||
self.register_buffer(
|
||||
"inv_timescales",
|
||||
min_timescale * torch.exp(torch.arange(num_timescales, dtype=torch.float32) * -log_timescale_increment),
|
||||
)
|
||||
|
||||
def forward(self, seq_length=None, position=None):
|
||||
"""Generates a Tensor of sinusoids with different frequencies.
|
||||
|
||||
Args:
|
||||
seq_length: an optional Python int defining the output sequence length.
Not required when the `position` argument is specified.
|
||||
position: [B, seq_length], optional position for each token in the
|
||||
sequence, only required when the sequence is packed.
|
||||
|
||||
Returns:
|
||||
[B, seqlen, D] if `position` is specified, else [1, seqlen, D]
|
||||
"""
|
||||
if position is None and seq_length is None:
|
||||
raise ValueError("Either position or seq_length must be provided")
|
||||
|
||||
if position is None:
|
||||
# [1, seqlen]
|
||||
position = torch.arange(seq_length, dtype=torch.float32, device=self.inv_timescales.device).unsqueeze(0)
|
||||
elif position.ndim != 2:
|
||||
raise ValueError(f"position must be 2-dimensional, got shape {position.shape}")
|
||||
|
||||
scaled_time = position.view(*position.shape, 1) * self.inv_timescales.view(1, 1, -1)
|
||||
signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)
|
||||
|
||||
# Padding to ensure correct embedding dimension
|
||||
signal = F.pad(signal, (0, 0, 0, self.embedding_dims % 2))
|
||||
return signal
|
||||
|
||||
|
||||
class TimesFmAttention(nn.Module):
|
||||
"""Implements the attention used in TimesFM. One key difference is that there is _per_dim_scaling of the query."""
|
||||
|
||||
def __init__(self, config: TimesFmConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.is_causal = True
|
||||
self.attention_dropout = config.attention_dropout
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.hidden_size = config.hidden_size
|
||||
self.head_dim = config.head_dim
|
||||
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_heads * self.head_dim
|
||||
self.scaling = nn.Parameter(torch.empty((self.head_dim,)))
|
||||
|
||||
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim)
|
||||
self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim)
|
||||
self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim)
|
||||
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size)
|
||||
|
||||
def _scale_query(self, query: torch.Tensor) -> torch.Tensor:
|
||||
scale = F.softplus(self.scaling).mul(1.442695041 / math.sqrt(self.head_dim))
|
||||
return query * scale[None, None, None, :]
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, self.head_dim)
|
||||
|
||||
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
query_states = self._scale_query(query_states)
|
||||
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
attention_interface: Callable = simple_eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
|
||||
logger.warning_once(
|
||||
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
|
||||
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
else:
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
scaling=1.0,
|
||||
**kwargs,
|
||||
)
|
||||
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class TimesFmDecoderLayer(nn.Module):
|
||||
"""Transformer layer."""
|
||||
|
||||
def __init__(self, config: TimesFmConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
|
||||
self.self_attn = TimesFmAttention(config, layer_idx=layer_idx)
|
||||
self.mlp = TimesFmMLP(config)
|
||||
self.input_layernorm = TimesFmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
paddings: torch.Tensor,
|
||||
output_attentions: bool = False,
|
||||
) -> tuple[Optional[torch.Tensor], torch.Tensor]:
|
||||
# Self Attention
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
hidden_states, scores = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# MLP
|
||||
hidden_states = self.mlp(hidden_states, paddings=paddings)
|
||||
|
||||
return scores, hidden_states
|
||||
|
||||
|
||||
TIMESFM_START_DOCSTRING = r"""
|
||||
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
|
||||
etc.)
|
||||
|
||||
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
||||
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
||||
and behavior.
|
||||
|
||||
Parameters:
|
||||
config ([`TimesFmConfig`]):
|
||||
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
||||
load the weights associated with the model, only the configuration. Check out the
|
||||
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare TimesFM Model outputting raw hidden-states without any specific head on top.",
|
||||
TIMESFM_START_DOCSTRING,
|
||||
)
|
||||
class TimesFmPreTrainedModel(PreTrainedModel):
|
||||
"""handles the loading for all models."""
|
||||
|
||||
config_class = TimesFmConfig
|
||||
base_model_prefix = "timesfm"
|
||||
_no_split_modules = ["TimesFmDecoderLayer"]
|
||||
main_input_name = "past_values"
|
||||
_supports_sdpa = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
if isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0, std=self.config.initializer_range)
|
||||
|
||||
elif isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
nn.init.zeros_(module.bias)
|
||||
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
nn.init.ones_(module.weight)
|
||||
nn.init.zeros_(module.bias)
|
||||
|
||||
elif isinstance(module, TimesFmRMSNorm):
|
||||
nn.init.zeros_(module.weight)
|
||||
|
||||
elif isinstance(module, TimesFmAttention):
|
||||
# Initialize scaling parameter
|
||||
nn.init.ones_(module.scaling)
|
||||
|
||||
|
||||
TIMESFM_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
past_values: list of time series forecast contexts. Each context time series
|
||||
is a torch Tensor; the contexts may have different lengths.
|
||||
freq: frequency of each context time series in the inputs. 0 for high frequency
|
||||
(default), 1 for medium, and 2 for low.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||||
tensors for more detail.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||
more detail.
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare TimesFM Model outputting raw hidden-states without any specific head on top.",
|
||||
TIMESFM_START_DOCSTRING,
|
||||
)
|
||||
class TimesFmModel(TimesFmPreTrainedModel):
|
||||
"""Patched time-series decoder without any specific output layer."""
|
||||
|
||||
def __init__(self, config: TimesFmConfig):
|
||||
super().__init__(config)
|
||||
|
||||
self.config = config
|
||||
self.input_ff_layer = TimesFmResidualBlock(
|
||||
input_dims=2 * config.patch_length,
|
||||
output_dims=config.hidden_size,
|
||||
hidden_dims=config.intermediate_size,
|
||||
)
|
||||
self.freq_emb = nn.Embedding(num_embeddings=config.freq_size, embedding_dim=config.hidden_size)
|
||||
self.layers = nn.ModuleList(
|
||||
[TimesFmDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||
)
|
||||
if self.config.use_positional_embedding:
|
||||
self.position_emb = TimesFmPositionalEmbedding(config=config)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def _forward_transform(
|
||||
self, inputs: torch.Tensor, patched_pads: torch.Tensor
|
||||
) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""Input is of shape [B, N, P]."""
|
||||
mu, sigma = self._timesfm_masked_mean_std(inputs, patched_pads)
|
||||
sigma = torch.where(
|
||||
sigma < self.config.tolerance,
|
||||
torch.tensor(1.0, dtype=sigma.dtype, device=sigma.device),
|
||||
sigma,
|
||||
)
|
||||
|
||||
# Normalize each patch
|
||||
outputs = (inputs - mu[:, None, None]) / sigma[:, None, None]
|
||||
outputs = torch.where(
|
||||
torch.abs(inputs - self.config.pad_val) < self.config.tolerance,
|
||||
torch.tensor(self.config.pad_val, dtype=outputs.dtype, device=outputs.device),
|
||||
outputs,
|
||||
)
|
||||
return outputs, (mu, sigma)
|
||||
|
||||
@can_return_tuple
|
||||
@add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
past_values: torch.Tensor,
|
||||
past_values_padding: torch.LongTensor,
|
||||
freq: torch.Tensor,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
) -> TimesFmOutput:
|
||||
"""
|
||||
past_values_padding (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
The padding indicator of the time series.
|
||||
"""
|
||||
# Reshape into patches (using view for efficiency)
|
||||
bsize = past_values.shape[0]
|
||||
patched_inputs = past_values.view(bsize, -1, self.config.patch_length)
|
||||
patched_pads = past_values_padding.view(bsize, -1, self.config.patch_length)
|
||||
|
||||
patched_inputs = torch.where(
|
||||
torch.abs(patched_pads - 1.0) < self.config.tolerance,
|
||||
torch.tensor(0.0, dtype=patched_inputs.dtype, device=patched_inputs.device),
|
||||
patched_inputs,
|
||||
)
|
||||
patched_pads = torch.where(
|
||||
torch.abs(patched_inputs - self.config.pad_val) < self.config.tolerance,
|
||||
torch.tensor(1.0, dtype=patched_pads.dtype, device=patched_pads.device),
|
||||
patched_pads,
|
||||
)
|
||||
patched_inputs, stats = self._forward_transform(patched_inputs, patched_pads)
|
||||
|
||||
# B x N x D
|
||||
patched_inputs = patched_inputs * (1.0 - patched_pads)
|
||||
concat_inputs = torch.cat([patched_inputs, patched_pads], dim=-1)
|
||||
model_input = self.input_ff_layer(concat_inputs)
|
||||
|
||||
# A patch is treated as non-padded if it contains at least one non-padded (zero) value.
|
||||
patched_padding = torch.min(patched_pads, dim=-1)[0] # Get the values from the min result
|
||||
if self.config.use_positional_embedding:
|
||||
pos_emb = self.position_emb(model_input.shape[1])
|
||||
pos_emb = torch.concat([pos_emb] * model_input.shape[0], dim=0)
|
||||
pos_emb = self._timesfm_shift_padded_seq(patched_padding, pos_emb)
|
||||
model_input += pos_emb
|
||||
|
||||
f_emb = self.freq_emb(freq) # B x 1 x D
|
||||
model_input += f_emb
|
||||
|
||||
# Convert paddings to attention mask and combine with causal mask
|
||||
hidden_states = model_input
|
||||
attention_mask = self._prepare_4d_attention_mask(
|
||||
attention_mask=patched_padding,
|
||||
sequence_length=hidden_states.shape[1],
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device,
|
||||
is_causal=True,
|
||||
)
|
||||
|
||||
all_attentions = []
|
||||
all_hidden_states = []
|
||||
|
||||
for layer in self.layers[: self.config.num_hidden_layers]:
|
||||
scores, hidden_states = layer(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
paddings=patched_padding,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
if output_attentions:
|
||||
all_attentions.append(scores)
|
||||
if output_hidden_states:
|
||||
all_hidden_states.append(hidden_states)
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states = [model_input] + all_hidden_states
|
||||
else:
|
||||
all_hidden_states = None
|
||||
|
||||
return TimesFmOutput(
|
||||
last_hidden_state=hidden_states,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_attentions if output_attentions else None,
|
||||
loc=stats[0],
|
||||
scale=stats[1],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _prepare_4d_attention_mask(
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
sequence_length: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
is_causal: bool = True,
|
||||
) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
Creates 4D attention mask and combines causal and padding masks if needed.
|
||||
|
||||
Args:
|
||||
attention_mask: Optional tensor of shape (batch_size, seq_length) containing padding mask
|
||||
sequence_length: Length of the sequence
|
||||
dtype: Data type of the mask
|
||||
device: Device of the mask
|
||||
is_causal: Whether to apply causal masking
|
||||
|
||||
Returns:
|
||||
4D attention mask of shape (batch_size, 1, seq_length, seq_length)
|
||||
"""
|
||||
# Get minimum value for the dtype
|
||||
min_value = torch.finfo(dtype).min if dtype.is_floating_point else torch.iinfo(dtype).min
|
||||
|
||||
# Handle padding mask
|
||||
if attention_mask is not None:
|
||||
# Convert 2D padding mask to 4D attention mask
|
||||
attention_mask = attention_mask.view(attention_mask.shape[0], 1, 1, -1)
|
||||
attention_mask = attention_mask * min_value
|
||||
|
||||
# Create causal mask if needed
|
||||
if is_causal:
|
||||
causal_mask = torch.triu(
|
||||
torch.ones((sequence_length, sequence_length), dtype=dtype, device=device) * min_value,
|
||||
diagonal=1,
|
||||
)
|
||||
causal_mask = causal_mask.view(1, 1, sequence_length, sequence_length)
|
||||
|
||||
# Combine with padding mask if it exists
|
||||
if attention_mask is not None:
|
||||
attention_mask = torch.minimum(attention_mask, causal_mask)
|
||||
else:
|
||||
attention_mask = causal_mask
|
||||
|
||||
return attention_mask
|
||||
|
||||
@staticmethod
|
||||
def _timesfm_masked_mean_std(inputs: torch.Tensor, padding: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Calculates mean and standard deviation of `inputs` across axis 1.
|
||||
|
||||
It excludes values where `padding` is 1.
|
||||
|
||||
Args:
|
||||
inputs: A PyTorch tensor of shape [b, n, p].
|
||||
padding: A PyTorch tensor of shape [b, n, p] with values 0 or 1.
|
||||
|
||||
Returns:
|
||||
A tuple containing the mean and standard deviation.
|
||||
We return the statistics of the first patch with more than three non-padded values.
|
||||
"""
|
||||
|
||||
# Selecting the first patch with more than 3 unpadded values.
|
||||
def _get_patch_index(arr: torch.Tensor):
|
||||
indices = torch.argmax((arr >= 3).to(torch.int32), dim=1)
|
||||
row_sum = (arr >= 3).to(torch.int32).sum(dim=1)
|
||||
return torch.where(row_sum == 0, arr.shape[1] - 1, indices)
|
||||
|
||||
pad_sum = torch.sum(1 - padding, dim=2)
|
||||
patch_indices = _get_patch_index(pad_sum)
|
||||
bidxs = torch.arange(inputs.shape[0])
|
||||
|
||||
arr = inputs[bidxs, patch_indices, :]
|
||||
pad = padding[bidxs, patch_indices, :]
|
||||
|
||||
# Create a mask where padding is 0
|
||||
mask = 1 - pad
|
||||
|
||||
# Calculate the number of valid elements
|
||||
num_valid_elements = torch.sum(mask, dim=1)
|
||||
num_valid_elements = torch.where(
|
||||
num_valid_elements == 0,
|
||||
torch.tensor(1, dtype=num_valid_elements.dtype, device=num_valid_elements.device),
|
||||
num_valid_elements,
|
||||
)
|
||||
|
||||
# Calculate the masked sum and squared sum
|
||||
masked_sum = torch.sum(arr * mask, dim=1)
|
||||
masked_squared_sum = torch.sum((arr * mask) ** 2, dim=1)
|
||||
|
||||
# Calculate the masked mean and standard deviation
|
||||
masked_mean = masked_sum / num_valid_elements
|
||||
masked_var = masked_squared_sum / num_valid_elements - masked_mean**2
|
||||
masked_var = torch.where(
|
||||
masked_var < 0.0,
|
||||
torch.tensor(0.0, dtype=masked_var.dtype, device=masked_var.device),
|
||||
masked_var,
|
||||
)
|
||||
masked_std = torch.sqrt(masked_var)
|
||||
|
||||
return masked_mean, masked_std
|
||||
|
||||
@staticmethod
|
||||
def _timesfm_shift_padded_seq(mask: torch.Tensor, seq: torch.Tensor) -> torch.Tensor:
|
||||
"""Shifts rows of seq based on the first 0 in each row of the mask.
|
||||
|
||||
Args:
|
||||
mask: mask tensor of shape [B, N]
|
||||
seq: seq tensor of shape [B, N, P]
|
||||
|
||||
Returns:
|
||||
The shifted sequence.
|
||||
"""
|
||||
batch_size, num_seq, feature_dim = seq.shape
|
||||
|
||||
new_mask: torch.BoolTensor = mask == 0
|
||||
|
||||
# Use argmax to find the first True value in each row
|
||||
indices = new_mask.to(torch.int32).argmax(dim=1)
|
||||
|
||||
# Handle rows with all zeros
|
||||
indices[~new_mask.any(dim=1)] = -1
|
||||
|
||||
# Create index ranges for each sequence in the batch
|
||||
idx_range = torch.arange(num_seq, device=seq.device).view(1, -1, 1).expand(batch_size, -1, feature_dim)
|
||||
|
||||
# Calculate shifted indices for each element in each sequence
|
||||
shifted_idx = (idx_range - indices[:, None, None]) % num_seq
|
||||
|
||||
# Gather values from seq using shifted indices
|
||||
shifted_seq = seq.gather(1, shifted_idx)
|
||||
|
||||
return shifted_seq
|
||||
|
||||
|
||||
class TimesFmModelForPrediction(TimesFmPreTrainedModel):
|
||||
"""TimesFM model for quantile and mean prediction."""
|
||||
|
||||
def __init__(self, config: TimesFmConfig):
|
||||
super().__init__(config)
|
||||
|
||||
self.config = config
|
||||
self.context_len = config.context_length
|
||||
self.horizon_len = config.horizon_length
|
||||
|
||||
self.decoder = TimesFmModel(config)
|
||||
|
||||
# quantile and mean output
|
||||
self.horizon_ff_layer = TimesFmResidualBlock(
|
||||
input_dims=config.hidden_size,
|
||||
output_dims=config.horizon_length * (1 + len(config.quantiles)),
|
||||
hidden_dims=config.intermediate_size,
|
||||
)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def _preprocess(
|
||||
self, inputs: Sequence[torch.Tensor], freq: Sequence[int]
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Formats and pads raw inputs to feed into the model.
|
||||
|
||||
This function pads (or truncates) each time series so that it matches the
model's required context length.
|
||||
|
||||
Args:
|
||||
inputs: A list of 1d Tensors. Each Tensor is the context time series of
|
||||
a single forecast task.
|
||||
freq: list of frequencies
|
||||
|
||||
Returns:
|
||||
A tuple of:
|
||||
- the padded input time series to meet the model required context.
|
||||
- the padding indicator.
|
||||
- the per-series frequency indicator, as an int32 tensor of shape (batch_size, 1).
|
||||
"""
|
||||
input_ts, input_padding, inp_freq = [], [], []
|
||||
|
||||
for i, ts in enumerate(inputs):
|
||||
input_len = ts.shape[0]
|
||||
padding = torch.zeros(input_len + self.horizon_len, dtype=ts.dtype, device=ts.device)
|
||||
if input_len < self.context_len:
|
||||
num_front_pad = self.context_len - input_len
|
||||
ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0)
|
||||
padding = torch.cat([torch.ones(num_front_pad, dtype=ts.dtype, device=padding.device), padding], dim=0)
|
||||
elif input_len > self.context_len:
|
||||
ts = ts[-self.context_len :]
|
||||
padding = padding[-(self.context_len + self.horizon_len) :]
|
||||
|
||||
input_ts.append(ts)
|
||||
input_padding.append(padding)
|
||||
inp_freq.append(freq[i])
|
||||
|
||||
return (
|
||||
torch.stack(input_ts, dim=0),
|
||||
torch.stack(input_padding, dim=0),
|
||||
torch.tensor(inp_freq, dtype=torch.int32).reshape(-1, 1),
|
||||
)
|
||||
|
||||
def _postprocess_output(
|
||||
self, model_output: torch.Tensor, stats: tuple[torch.Tensor, torch.Tensor]
|
||||
) -> torch.Tensor:
|
||||
"""Postprocess output of stacked transformer."""
|
||||
|
||||
# B x N x (H.Q)
|
||||
output_ts = self.horizon_ff_layer(model_output)
|
||||
|
||||
# Reshape using view
|
||||
b, n, _ = output_ts.shape
|
||||
output_ts = output_ts.view(b, n, self.config.horizon_length, len(self.config.quantiles) + 1)
|
||||
|
||||
mu, sigma = stats
|
||||
return output_ts * sigma[:, None, None, None] + mu[:, None, None, None]
|
||||
|
||||
def _quantile_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
|
||||
losses = []
|
||||
for i, q in enumerate(self.config.quantiles):
|
||||
errors = targets - predictions[..., i]
|
||||
loss = torch.max((q - 1) * errors, q * errors)
|
||||
losses.append(loss.mean())
|
||||
return torch.stack(losses).mean()
|
||||
|
||||
@can_return_tuple
|
||||
@add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=TimesFmOutputForPrediction, config_class=_CONFIG_FOR_DOC)
|
||||
@add_code_sample_docstrings(
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=TimesFmOutputForPrediction,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
past_values: Sequence[torch.Tensor],
|
||||
freq: Optional[Sequence[Union[torch.Tensor, int]]] = None,
|
||||
window_size: Optional[int] = None,
|
||||
future_values: Optional[torch.Tensor] = None,
|
||||
forecast_context_len: Optional[int] = None,
|
||||
return_forecast_on_context: bool = False,
|
||||
truncate_negative: bool = False,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
) -> TimesFmOutputForPrediction:
|
||||
r"""
|
||||
window_size (`int`, *optional*):
|
||||
Window size of trend + residual decomposition. If None then we do not do decomposition.
|
||||
future_values (`torch.Tensor`, *optional*):
|
||||
Optional future time series values to be used for loss computation.
|
||||
forecast_context_len (`int`, *optional*):
|
||||
Optional max context length.
|
||||
return_forecast_on_context (`bool`, *optional*):
|
||||
True to return the forecast on the context when available, i.e. after the first input patch.
|
||||
truncate_negative (`bool`, *optional*):
|
||||
Truncate the forecast to non-negative values if all context values are non-negative
(i.e. the minimum over the contexts is >= 0); otherwise do nothing.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether to output the attentions.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether to output the hidden states.
|
||||
|
||||
Returns:
|
||||
A TimesFmOutputForPrediction object or a tuple containing:
|
||||
- the mean forecast of size (# past_values, # forecast horizon),
|
||||
- the full forecast (mean + quantiles) of size
|
||||
(# past_values, # forecast horizon, 1 + # quantiles).
|
||||
- loss: the mean squared error loss + quantile loss if `future_values` is provided.
|
||||
"""
|
||||
if forecast_context_len is None:
|
||||
fcontext_len = self.context_len
|
||||
else:
|
||||
fcontext_len = forecast_context_len
|
||||
|
||||
# Get device from first input tensor
|
||||
device = past_values[0].device
|
||||
|
||||
# Truncate inputs to forecast_context_len
|
||||
inputs = [ts[-fcontext_len:] for ts in past_values]
|
||||
inp_min = torch.min(torch.stack([torch.min(ts) for ts in inputs]))
|
||||
|
||||
if window_size is not None:
|
||||
new_inputs = []
|
||||
new_freqs = []
|
||||
for i, ts in enumerate(inputs):
|
||||
new_inputs.extend(self._timesfm_moving_average(ts, window_size))
|
||||
if freq is not None:
|
||||
new_freqs.extend([freq[i]] * 2)
|
||||
inputs = new_inputs
|
||||
if freq is not None:
|
||||
freq = new_freqs
|
||||
|
||||
if freq is None:
|
||||
logger.info("No frequency provided via `freq`. Default to high (0).")
|
||||
freq = [0] * len(inputs)
|
||||
|
||||
if output_attentions is None:
|
||||
output_attentions = self.config.output_attentions
|
||||
if output_hidden_states is None:
|
||||
output_hidden_states = self.config.output_hidden_states
|
||||
|
||||
input_ts, input_padding, inp_freq = self._preprocess(inputs, freq)
|
||||
# Move tensors to the same device as input
|
||||
input_ts = input_ts.to(device)
|
||||
input_padding = input_padding.to(device)
|
||||
inp_freq = inp_freq.to(device)
|
||||
|
||||
final_out = input_ts
|
||||
context_len = final_out.shape[1]
|
||||
full_outputs = []
|
||||
|
||||
if input_padding.shape[1] != final_out.shape[1] + self.horizon_len:
|
||||
raise ValueError(
|
||||
"Length of paddings must match length of input + horizon_len:"
|
||||
f" {input_padding.shape[1]} != {final_out.shape[1]} + {self.horizon_len}"
|
||||
)
|
||||
output_patch_len = self.config.horizon_length
|
||||
|
||||
num_decode_patches = (self.horizon_len + output_patch_len - 1) // output_patch_len
|
||||
for step_index in range(num_decode_patches):
|
||||
current_padding = input_padding[:, 0 : final_out.shape[1]]
|
||||
input_ts = final_out[:, -fcontext_len:]
|
||||
input_padding = current_padding[:, -fcontext_len:]
|
||||
decoder_output = self.decoder(
|
||||
past_values=input_ts,
|
||||
past_values_padding=input_padding,
|
||||
freq=inp_freq,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
fprop_outputs = self._postprocess_output(
|
||||
decoder_output.last_hidden_state,
|
||||
(decoder_output.loc, decoder_output.scale),
|
||||
)
|
||||
|
||||
if return_forecast_on_context and step_index == 0:
|
||||
# For the first decoding step, collect the model forecast on the
# context, except for the unavailable forecast of the first input patch.
|
||||
new_full_ts = fprop_outputs[:, :-1, : self.config.patch_length, :]
|
||||
# We have to use reshape and not view for non-contiguous memory
|
||||
new_full_ts = new_full_ts.reshape(new_full_ts.size(0), -1, new_full_ts.size(3))
|
||||
|
||||
full_outputs.append(new_full_ts)
|
||||
|
||||
# (full batch, last patch, output_patch_len, index of mean forecast = 0)
|
||||
new_ts = fprop_outputs[:, -1, :output_patch_len, 0]
|
||||
new_full_ts = fprop_outputs[:, -1, :output_patch_len, :]
|
||||
# (full batch, last patch, output_patch_len, all output indices)
|
||||
full_outputs.append(new_full_ts)
|
||||
final_out = torch.concatenate([final_out, new_ts], axis=-1)
|
||||
|
||||
if return_forecast_on_context:
|
||||
# `full_outputs` indexing starts right after the first input patch.
|
||||
full_outputs = torch.concatenate(full_outputs, axis=1)[
|
||||
:, : (context_len - self.config.patch_length + self.horizon_len), :
|
||||
]
|
||||
else:
|
||||
# `full_outputs` indexing starts at the forecast horizon.
|
||||
full_outputs = torch.concatenate(full_outputs, axis=1)[:, 0 : self.horizon_len, :]
|
||||
|
||||
mean_outputs = full_outputs[:, :, 0]
|
||||
if window_size is not None:
|
||||
mean_outputs = mean_outputs[0::2, ...] + mean_outputs[1::2, ...]
|
||||
full_outputs = full_outputs[0::2, ...] + full_outputs[1::2, ...]
|
||||
if inp_min >= 0 and truncate_negative:
|
||||
# `torch.maximum` requires a tensor operand, so clamp at zero instead
mean_outputs = torch.clamp(mean_outputs, min=0.0)
full_outputs = torch.clamp(full_outputs, min=0.0)
|
||||
|
||||
loss = None
|
||||
if future_values is not None:
|
||||
mse_loss = F.mse_loss(mean_outputs, future_values)
|
||||
quantile_loss = self._quantile_loss(full_outputs[:, :, 1:], future_values)
|
||||
loss = mse_loss + quantile_loss
|
||||
|
||||
return TimesFmOutputForPrediction(
|
||||
last_hidden_state=decoder_output.last_hidden_state,
|
||||
attentions=decoder_output.attentions if output_attentions else None,
|
||||
hidden_states=decoder_output.hidden_states if output_hidden_states else None,
|
||||
mean_predictions=mean_outputs,
|
||||
full_predictions=full_outputs,
|
||||
loss=loss,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _timesfm_moving_average(arr: torch.Tensor, window_size: int) -> list[torch.Tensor]:
|
||||
"""Calculates the moving average using PyTorch's convolution function."""
|
||||
# Pad with zeros to handle initial window positions
|
||||
arr_padded = F.pad(arr, (window_size - 1, 0), "constant", 0)
|
||||
# Create a convolution kernel
|
||||
kernel = torch.ones(window_size, dtype=arr.dtype, device=arr.device) / window_size
|
||||
# Apply convolution to calculate the moving average
|
||||
smoothed_arr = F.conv1d(arr_padded.view(1, 1, -1), kernel.view(1, 1, -1)).squeeze()
|
||||
return [smoothed_arr, arr - smoothed_arr]
|
||||
|
||||
|
||||
__all__ = ["TimesFmModelForPrediction", "TimesFmPreTrainedModel", "TimesFmModel"]
|
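As a sanity check on the `_quantile_loss` method implemented in both files above, the following standalone sketch computes the same pinball loss on toy data. The quantile levels here are illustrative, not the model defaults, and the helper name is made up for this sketch.

import torch

def pinball_loss(predictions: torch.Tensor, targets: torch.Tensor, quantiles) -> torch.Tensor:
    # `predictions` has one trailing dimension per quantile level.
    losses = []
    for i, q in enumerate(quantiles):
        errors = targets - predictions[..., i]
        # Under-prediction is weighted by q, over-prediction by (1 - q).
        losses.append(torch.max((q - 1) * errors, q * errors).mean())
    return torch.stack(losses).mean()

targets = torch.tensor([1.0, 2.0, 3.0])
predictions = torch.stack([targets - 0.5, targets, targets + 0.5], dim=-1)
print(pinball_loss(predictions, targets, quantiles=[0.1, 0.5, 0.9]))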
tests/models/timesfm/__init__.py (new file, 0 lines)
tests/models/timesfm/test_modeling_timesfm.py (new file, 197 lines)
@@ -0,0 +1,197 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 Google LLC and HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import unittest
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from transformers import TimesFmConfig, is_torch_available
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
from transformers.utils import is_torch_fx_available
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin
|
||||
|
||||
|
||||
if is_torch_fx_available():
|
||||
pass
|
||||
|
||||
if is_torch_available():
|
||||
from transformers import TimesFmModelForPrediction
|
||||
|
||||
TOLERANCE = 1e-4
|
||||
|
||||
|
||||
class TimesFmModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
patch_length: int = 32,
|
||||
context_length: int = 512,
|
||||
horizon_length: int = 128,
|
||||
freq_size: int = 3,
|
||||
num_hidden_layers: int = 1,
|
||||
hidden_size: int = 16,
|
||||
intermediate_size: int = 32,
|
||||
head_dim: int = 8,
|
||||
num_heads: int = 2,
|
||||
tolerance: float = 1e-6,
|
||||
rms_norm_eps: float = 1e-6,
|
||||
quantiles: List[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
|
||||
pad_val: float = 1123581321.0,
|
||||
use_positional_embedding: bool = True,
|
||||
initializer_factor: float = 0.0,
|
||||
is_training: bool = False,
|
||||
batch_size: int = 3,
|
||||
):
|
||||
self.parent = parent
|
||||
self.patch_length = patch_length
|
||||
self.context_length = context_length
|
||||
self.horizon_length = horizon_length
|
||||
self.quantiles = quantiles
|
||||
self.pad_val = pad_val
|
||||
self.freq_size = freq_size
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.head_dim = head_dim
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_heads
|
||||
self.tolerance = tolerance
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_positional_embedding = use_positional_embedding
|
||||
self.initializer_factor = initializer_factor
|
||||
self.is_training = is_training
|
||||
self.batch_size = batch_size
|
||||
|
||||
# The size of test input
|
||||
self.seq_length = context_length // patch_length
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
def get_config(self):
|
||||
return TimesFmConfig(
|
||||
patch_length=self.patch_length,
|
||||
context_length=self.context_length,
|
||||
horizon_length=self.horizon_length,
|
||||
quantiles=self.quantiles,
|
||||
pad_val=self.pad_val,
|
||||
freq_size=self.freq_size,
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=self.intermediate_size,
|
||||
head_dim=self.head_dim,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
tolerance=self.tolerance,
|
||||
rms_norm_eps=self.rms_norm_eps,
|
||||
use_positional_embedding=self.use_positional_embedding,
|
||||
initializer_factor=self.initializer_factor,
|
||||
)
|
||||
|
||||
def get_pipeline_config(self):
|
||||
return self.get_config()
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
forecast_input = [
|
||||
torch.tensor(np.sin(np.linspace(0, 20, 100)), dtype=torch.float32, device=torch_device),
|
||||
torch.tensor(np.cos(np.linspace(0, 20, 100)), dtype=torch.float32, device=torch_device),
|
||||
torch.tensor(np.tan(np.linspace(0, 20, 100)), dtype=torch.float32, device=torch_device),
|
||||
]
|
||||
frequency_input = torch.tensor([0, 1, 2], dtype=torch.long, device=torch_device)
|
||||
|
||||
return (self.get_config(), torch.stack(forecast_input, dim=0), frequency_input)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
(config, forecast_input, frequency_input) = self.prepare_config_and_inputs()
|
||||
|
||||
inputs_dict = {
|
||||
"past_values": forecast_input,
|
||||
"freq": frequency_input,
|
||||
}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_torch
|
||||
class TimesFmModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (TimesFmModelForPrediction,) if is_torch_available() else ()
|
||||
all_generative_model_classes = ()
|
||||
all_parallelizable_model_classes = ()
|
||||
fx_compatible = False
|
||||
test_pruning = False
|
||||
test_resize_embeddings = False
|
||||
test_model_parallel = False
|
||||
is_encoder_decoder = False
|
||||
test_inputs_embeds = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TimesFmModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=TimesFmConfig)
|
||||
|
||||
def test_create_and_run_model(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = TimesFmModelForPrediction(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
results = model(**inputs_dict)
|
||||
assert results.mean_predictions is not None
|
||||
|
||||
@unittest.skip(reason="Compile not yet supported because of masks")
|
||||
def test_sdpa_can_dispatch_on_flash(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Model does not have input embeddings")
|
||||
def test_model_get_set_embeddings(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Model does not have head mask")
|
||||
def test_headmasking(self):
|
||||
pass
|
||||
|
||||
# the main input name is `inputs`
|
||||
def test_model_main_input_name(self):
|
||||
model_signature = inspect.signature(getattr(TimesFmModelForPrediction, "forward"))
|
||||
# The main input is the name of the argument after `self`
|
||||
observed_main_input_name = list(model_signature.parameters.keys())[1]
|
||||
self.assertEqual(TimesFmModelForPrediction.main_input_name, observed_main_input_name)
|
||||
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
class TimesFmModelIntegrationTests(unittest.TestCase):
|
||||
def test_inference_no_head(self):
|
||||
model = TimesFmModelForPrediction.from_pretrained("google/timesfm-2.0-500m-pytorch", revision="refs/pr/7").to(
|
||||
torch_device
|
||||
)
|
||||
forecast_input = [
|
||||
np.sin(np.linspace(0, 20, 100)),
|
||||
np.sin(np.linspace(0, 20, 200)),
|
||||
np.sin(np.linspace(0, 20, 400)),
|
||||
]
|
||||
forecast_input_tensor = [torch.tensor(ts, dtype=torch.float32, device=torch_device) for ts in forecast_input]
|
||||
frequency_input = [0, 1, 2]
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(past_values=forecast_input_tensor, freq=frequency_input).last_hidden_state
|
||||
|
||||
self.assertEqual(
|
||||
output.shape,
|
||||
torch.Size([3, model.config.context_length // model.config.patch_length, model.config.hidden_size]),
|
||||
)
|
||||
expected_slice = torch.tensor(
|
||||
[[-0.4267, -0.7273, -0.3932], [-0.4267, -0.7273, -0.3932], [-0.4267, -0.7273, -0.3932]],
|
||||
device=torch_device,
|
||||
)
|
||||
self.assertTrue(torch.allclose(output[0, :3, :3], expected_slice, atol=TOLERANCE))
|
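The `window_size` argument to `TimesFmModelForPrediction.forward` triggers the trend + residual decomposition implemented by `_timesfm_moving_average`: each series is split into a causal moving average and its residual, both parts are forecast separately, and the interleaved outputs are summed back together. A small self-contained sketch of that decomposition, using the same left-padded mean filter as the helper above (the function name here is only for illustration):

import torch
import torch.nn.functional as F

def moving_average_decompose(series: torch.Tensor, window_size: int):
    # Zero-pad on the left so the filter is causal, then apply a mean filter.
    padded = F.pad(series, (window_size - 1, 0), "constant", 0)
    kernel = torch.ones(window_size, dtype=series.dtype) / window_size
    trend = F.conv1d(padded.view(1, 1, -1), kernel.view(1, 1, -1)).squeeze()
    return trend, series - trend

series = torch.arange(10, dtype=torch.float32)
trend, residual = moving_average_decompose(series, window_size=3)
# The two components always recombine exactly to the original series.
assert torch.allclose(trend + residual, series)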
@ -156,6 +156,7 @@ IGNORE_NON_TESTED = (
|
||||
"Llama4VisionModel", # Building part of bigger (tested) model. # TODO: add tests
|
||||
"Emu3VQVAE", # Building part of bigger (tested) model
|
||||
"Emu3TextModel", # Building part of bigger (tested) model
|
||||
"TimesFmModel", # Building part of bigger (tested) model
|
||||
]
|
||||
)
|
||||
|
||||
|