Add TimesFM Time Series Forecasting Model (#34082)

* initial documentation

* rename mask to attention_mask

* smaller tests

* fixup

* fix copies

* move to time series section

* sort docs

* isort fix

* batch_size is not a configuration

* rename to TimesFMModelForPrediction

* initial script

* add check_outputs

* remove dropout_rate

* works with torch.Tensor inputs

* rename script

* fix docstrings

* fix freq when window_size is given

* add loss

* fix _quantile_loss

* formatting

* fix isort

* add weight init

* add support for sdpa and flash_attention_2

* fixes for flash_attention

* formatting

* remove flash_attention

* fix tests

* fix file name

* fix quantile loss

* added initial TimesFMModelIntegrationTests

* fix formatting

* fix import order

* fix _quantile_loss

* add doc for SDPA

* use timesfm 2.0

* bug fix in timesfm decode function.

* compare mean forecasts

* refactor type hints, use CamelCase

* consolidate decode func

* more readable code for weight conversion

* fix-copies

* simpler init

* rename TimesFmMLP

* use T5LayerNorm

* fix tests

* use initializer_range

* TimesFmModel instead of TimesFmDecoder

* TimesFmPositionalEmbedding takes config for its init

* 2.0-500m-pytorch default configs

* use TimesFmModel

* fix formatting

* ignore TimesFmModel for testing

* fix docstring

* override generate as it's not needed

* add doc strings

* fix logging

* add docstrings to output data classes

* initial copy from t5

* added config and attention layers

* add TimesFMPositionalEmbedding

* calculate scale_factor once

* add more configs and TimesFMResidualBlock

* fix input_dims

* standardize code format with black

* remove unneeded modules

* TimesFM Model

* order of imports

* copy from Google official implementation

* remove covariate forecasting

* Adapting TimesFM to HF format

* restructuring in progress

* adapted to HF convention

* timesfm test

* the model runs

* fixing unit tests

* fixing unit tests in progress

* add post_init

* do not change TimesFMOutput

* fixing unit tests

* all unit tests passed

* remove timesfm_layers

* add intermediate_size and initialize with config

* add _CHECKPOINT_FOR_DOC

* fix comments

* Revert "fix comments"

This reverts commit 8deeb3e191.

* add _prepare_4d_attention_mask

* we do not have generative model classes

* use Cache

* return past_key_values

* modules initialized with config only

* update year

* Update docs/source/en/model_doc/timesfm.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* add layer_idx to cache

* modular timesfm

* fix test

* unwrap sequential class

* fix toctree

* remove TimesFmOnnxConfig

* fix modular

* remove TimesFmStackedDecoder

* split qkv layer into individual layers

* rename projection layers

* use ALL_ATTENTION_FUNCTIONS

* is_causal is True

* rename config

* does not support flash_attn_2

* formatting

* fix typo in docstring

* rename inputs

* add time series mapping

* Update src/transformers/models/olmo2/modeling_olmo2.py

* Update src/transformers/models/moonshine/modeling_moonshine.py

* use updated arguments

* fix class name

* add MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING

* isort

* consolidate _preprocess into forward

* fix a typo

* fix a typo

* fix toc

* fix modular

* remove asserts

* use self.config._attn_implementation

* move to _postprocess_output

* remove timesfm_get_large_negative_number

* use view instead of multiple unsqueeze

* make helpers static methods of the Model

* use to_tuple

* use to_tuple if not return_dict

* remove unused initialization block as it's incorporated in nn.Linear

* remove unused num_key_value_groups

* use the same convention as the masking method

* update modular

* do not use unsqueeze

* use view instead of unsqueeze

* use buffer for inv_timescales

* formatting

* modular conversion

* remove unneeded initialization

* add missing docstrings

* remove cache

* use simple_eager_attention_forward

* support tp_plan

* support for flex and flash attention masks

* Revert "support for flex and flash attention masks"

This reverts commit def36c4fcf.

* fix device

* fix tests on gpu

* remove unused large model test

* removed unneeded comments

* add example usage

* fix style

* add import

* Update docs/source/en/model_doc/timesfm.md

Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com>

* inherit from LlamaRMSNorm

* use can_return_tuple decorator

* remove return_dict

* fix year

* Update docs/source/en/model_doc/timesfm.md

Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com>

* pretrained does not inherit from GenerationMixin

* use model for integration test

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Rajat Sen <rsen91@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com>
Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co>
Jinan Zhou 2025-04-16 06:00:53 -07:00 committed by GitHub
parent 8669c016d2
commit a91020aed0
13 changed files with 2507 additions and 0 deletions

View File

@@ -1059,6 +1059,8 @@
title: PatchTST
- local: model_doc/time_series_transformer
title: Time Series Transformer
- local: model_doc/timesfm
title: TimesFM
title: Time series models
- sections:
- local: model_doc/graphormer

View File

@@ -0,0 +1,88 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# TimesFM
<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
</div>
## Overview
TimesFM (Time Series Foundation Model) is a pretrained time-series foundation model proposed in [A decoder-only foundation model for time-series forecasting](https://huggingface.co/papers/2310.10688) by Abhimanyu Das, Weihao Kong, Rajat Sen, and Yichen Zhou. It is a decoder-only model that takes non-overlapping patches of time-series data as input and predicts future values one output patch at a time, in an autoregressive fashion.
The abstract from the paper is the following:
*Motivated by recent advances in large language models for Natural Language Processing (NLP), we design a time-series foundation model for forecasting whose out-of-the-box zero-shot performance on a variety of public datasets comes close to the accuracy of state-of-the-art supervised forecasting models for each individual dataset. Our model is based on pretraining a patched-decoder style attention model on a large time-series corpus, and can work well across different forecasting history lengths, prediction lengths and temporal granularities.*
This model was contributed by [kashif](https://huggingface.co/kashif).
The original code can be found [here](https://github.com/google-research/timesfm).
To use the model:
```python
import numpy as np
import torch
from transformers import TimesFmModelForPrediction
model = TimesFmModelForPrediction.from_pretrained(
"google/timesfm-2.0-500m-pytorch",
torch_dtype=torch.bfloat16,
attn_implementation="sdpa",
device_map="cuda" if torch.cuda.is_available() else None
)
# Create dummy inputs
forecast_input = [
np.sin(np.linspace(0, 20, 100)),
np.sin(np.linspace(0, 20, 200)),
np.sin(np.linspace(0, 20, 400)),
]
frequency_input = [0, 1, 2]
# Convert inputs to sequence of tensors
forecast_input_tensor = [
torch.tensor(ts, dtype=torch.bfloat16).to("cuda" if torch.cuda.is_available() else "cpu")
for ts in forecast_input
]
frequency_input_tensor = torch.tensor(frequency_input, dtype=torch.long).to(
"cuda" if torch.cuda.is_available() else "cpu"
)
# Get predictions from the pre-trained model
with torch.no_grad():
outputs = model(past_values=forecast_input_tensor, freq=frequency_input_tensor, return_dict=True)
point_forecast_conv = outputs.mean_predictions.float().cpu().numpy()
quantile_forecast_conv = outputs.full_predictions.float().cpu().numpy()
```
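As a rough guide (assuming the default `google/timesfm-2.0-500m-pytorch` configuration with `horizon_length=128` and nine quantiles), each output has one row per input series; the shapes below are illustrative:

```python
# Illustrative shapes for the example above (3 input series):
# point_forecast_conv:    (3, 128)      -> mean forecast per series and horizon step
# quantile_forecast_conv: (3, 128, 10)  -> mean + 9 quantile forecasts
print(point_forecast_conv.shape, quantile_forecast_conv.shape)
```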
## TimesFmConfig
[[autodoc]] TimesFmConfig
## TimesFmModel
[[autodoc]] TimesFmModel
- forward
## TimesFmModelForPrediction
[[autodoc]] TimesFmModelForPrediction
- forward

View File

@@ -279,6 +279,7 @@ if TYPE_CHECKING:
from .tapas import *
from .textnet import *
from .time_series_transformer import *
from .timesfm import *
from .timesformer import *
from .timm_backbone import *
from .timm_wrapper import *

View File

@@ -313,6 +313,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("tapas", "TapasConfig"),
("textnet", "TextNetConfig"),
("time_series_transformer", "TimeSeriesTransformerConfig"),
("timesfm", "TimesFmConfig"),
("timesformer", "TimesformerConfig"),
("timm_backbone", "TimmBackboneConfig"),
("timm_wrapper", "TimmWrapperConfig"),
@@ -681,6 +682,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("tapex", "TAPEX"),
("textnet", "TextNet"),
("time_series_transformer", "Time Series Transformer"),
("timesfm", "TimesFm"),
("timesformer", "TimeSformer"),
("timm_backbone", "TimmBackbone"),
("timm_wrapper", "TimmWrapperModel"),

View File

@@ -281,6 +281,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("tapas", "TapasModel"),
("textnet", "TextNetModel"),
("time_series_transformer", "TimeSeriesTransformerModel"),
("timesfm", "TimesFmModel"),
("timesformer", "TimesformerModel"),
("timm_backbone", "TimmBackbone"),
("timm_wrapper", "TimmWrapperModel"),
@@ -1542,6 +1543,12 @@ MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING_NAMES = OrderedDict(
]
)
MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING_NAMES = OrderedDict(
[
("timesfm", "TimesFmModelForPrediction"),
]
)
MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES = OrderedDict(
[
("swin2sr", "Swin2SRForImageSuperResolution"),
@@ -1650,6 +1657,10 @@ MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING_NAMES
)
MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING_NAMES
)
MODEL_FOR_IMAGE_TO_IMAGE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES)
@@ -1820,6 +1831,15 @@ AutoModelForSemanticSegmentation = auto_class_update(
)
class AutoModelForTimeSeriesPrediction(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING
AutoModelForTimeSeriesPrediction = auto_class_update(
AutoModelForTimeSeriesPrediction, head_doc="time-series prediction"
)
class AutoModelForUniversalSegmentation(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING
@@ -1994,6 +2014,7 @@ __all__ = [
"MODEL_FOR_TEXT_ENCODING_MAPPING",
"MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING",
"MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING",
"MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING",
"MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
"MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING",
"MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING",

View File

@@ -0,0 +1,27 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .configuration_timesfm import *
from .modeling_timesfm import *
else:
import sys
_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)

View File

@@ -0,0 +1,129 @@
# coding=utf-8
# Copyright 2025 Google LLC and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TimesFM model configuration"""
from typing import List
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
class TimesFmConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`TimesFmModel`] or a [`TimesFmModelForPrediction`]. It is used to
instantiate a TimesFM model according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the TimesFM
[google/timesfm-2.0-500m-pytorch](https://huggingface.co/google/timesfm-2.0-500m-pytorch) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Arguments:
patch_length (`int`, *optional*, defaults to 32):
The length of one patch in the input sequence.
context_length (`int`, *optional*, defaults to 512):
The length of the input context.
horizon_length (`int`, *optional*, defaults to 128):
The length of the prediction horizon.
freq_size (`int`, *optional*, defaults to 3):
The number of frequency embeddings.
num_hidden_layers (`int`, *optional*, defaults to 50):
Number of Transformer layers.
hidden_size (`int`, *optional*, defaults to 1280):
Dimensionality of the hidden representations (model dimension).
intermediate_size (`int`, *optional*, defaults to 1280):
Dimension of the MLP representations.
head_dim (`int`, *optional*, defaults to 80):
Size of the key, query, value projections per attention head. The `inner_dim` of the projection layer will
be defined as `num_attention_heads * head_dim`.
num_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer encoder.
tolerance (`float`, *optional*, defaults to 1e-06):
The tolerance for the quantile loss.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the RMS normalization layers.
quantiles (`List[float]`, *optional*, defaults to `[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]`):
The quantiles to predict.
pad_val (`float`, *optional*, defaults to 1123581321.0):
The value used to pad the predictions.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout probability for the attention scores.
use_positional_embedding (`bool`, *optional*, defaults to `False`):
Whether to add positional embeddings.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
min_timescale (`int`, *optional*, defaults to 1):
The start of the geometric positional index. Determines the periodicity of
the added signal.
max_timescale (`int`, *optional*, defaults to 10000):
The end of the geometric positional index. Determines the frequency of the
added signal.
"""
model_type = "timesfm"
keys_to_ignore_at_inference = []
is_encoder_decoder = False
def __init__(
self,
patch_length: int = 32,
context_length: int = 512,
horizon_length: int = 128,
freq_size: int = 3,
num_hidden_layers: int = 50,
hidden_size: int = 1280,
intermediate_size: int = 1280,
head_dim: int = 80,
num_attention_heads: int = 16,
tolerance: float = 1e-6,
rms_norm_eps: float = 1e-6,
quantiles: List[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
pad_val: float = 1123581321.0,
attention_dropout: float = 0.0,
use_positional_embedding: bool = False,
initializer_range: float = 0.02,
min_timescale: int = 1,
max_timescale: int = 10_000,
**kwargs,
):
self.patch_length = patch_length
self.context_length = context_length
self.horizon_length = horizon_length
self.quantiles = quantiles
self.pad_val = pad_val
self.freq_size = freq_size
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.head_dim = head_dim
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.tolerance = tolerance
self.rms_norm_eps = rms_norm_eps
self.attention_dropout = attention_dropout
self.use_positional_embedding = use_positional_embedding
self.initializer_range = initializer_range
self.min_timescale = min_timescale
self.max_timescale = max_timescale
super().__init__(
is_encoder_decoder=self.is_encoder_decoder,
**kwargs,
)
__all__ = ["TimesFmConfig"]
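As a quick illustration of the defaults above, the config can be overridden before instantiating a model from scratch; a minimal sketch (the small sizes are arbitrary, chosen only to keep the example light):

```python
from transformers import TimesFmConfig, TimesFmModelForPrediction

# Start from the 2.0-500m defaults, then shrink the network for a toy, randomly initialized model
config = TimesFmConfig(num_hidden_layers=2, hidden_size=128, intermediate_size=128, head_dim=8)
model = TimesFmModelForPrediction(config)
```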

View File

@@ -0,0 +1,275 @@
import argparse
import os
import re
import shutil
import numpy as np
import timesfm
import torch
from transformers import TimesFmConfig, TimesFmModelForPrediction
"""
Sample usage:
```
python src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py \
--output_dir /output/path
```
"""
def get_nested_attr(obj, key):
"""Recursively retrieves an attribute from an object, handling list/tuple indexing if present."""
parts = key.split(".")
for part in parts:
match = re.match(r"(.*)\[(\d+)\]", part) # Handle list indexing like `layers[0]`
if match:
attr_name, index = match.groups()
obj = getattr(obj, attr_name)[int(index)] # Access list/tuple element
else:
obj = getattr(obj, part) # Regular attribute access
return obj
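# Illustrative example (not part of the conversion mapping): a key such as
# "decoder.layers[0].self_attn.q_proj.weight" walks `obj.decoder`, indexes `layers[0]`,
# then `self_attn.q_proj`, and finally returns the `.weight` tensor.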
def write_model(model_path, safe_serialization=True, huggingface_repo_id="google/timesfm-2.0-500m-pytorch"):
os.makedirs(model_path, exist_ok=True)
tmp_model_path = os.path.join(model_path, "tmp")
os.makedirs(tmp_model_path, exist_ok=True)
tfm = timesfm.TimesFm(
hparams=timesfm.TimesFmHparams(
backend="cuda" if torch.cuda.is_available() else "cpu",
per_core_batch_size=32,
horizon_len=128,
input_patch_len=32,
output_patch_len=128,
num_layers=50,
model_dims=1280,
use_positional_embedding=False,
),
checkpoint=timesfm.TimesFmCheckpoint(huggingface_repo_id=huggingface_repo_id),
)
timesfm_config = TimesFmConfig(
patch_length=tfm.hparams.input_patch_len,
context_length=tfm.hparams.context_len,
horizon_length=tfm.hparams.horizon_len,
num_hidden_layers=tfm.hparams.num_layers,
hidden_size=tfm.hparams.model_dims,
intermediate_size=tfm.hparams.model_dims,
head_dim=tfm.hparams.model_dims // tfm.hparams.num_heads,
num_attention_heads=tfm.hparams.num_heads,
use_positional_embedding=tfm.hparams.use_positional_embedding,
)
timesfm_config.save_pretrained(tmp_model_path)
timesfm_model = TimesFmModelForPrediction(timesfm_config)
# copy the weights from the original model to the new model
original_model = tfm._model
# mapping of the layers from the original model to the transformer model
MODEL_LAYER_MAPPING = {
"input_ff_layer.hidden_layer[0].weight": "decoder.input_ff_layer.input_layer.weight",
"input_ff_layer.hidden_layer[0].bias": "decoder.input_ff_layer.input_layer.bias",
"input_ff_layer.output_layer.weight": "decoder.input_ff_layer.output_layer.weight",
"input_ff_layer.output_layer.bias": "decoder.input_ff_layer.output_layer.bias",
"input_ff_layer.residual_layer.weight": "decoder.input_ff_layer.residual_layer.weight",
"input_ff_layer.residual_layer.bias": "decoder.input_ff_layer.residual_layer.bias",
"freq_emb.weight": "decoder.freq_emb.weight",
"horizon_ff_layer.hidden_layer[0].weight": "horizon_ff_layer.input_layer.weight",
"horizon_ff_layer.hidden_layer[0].bias": "horizon_ff_layer.input_layer.bias",
"horizon_ff_layer.output_layer.weight": "horizon_ff_layer.output_layer.weight",
"horizon_ff_layer.output_layer.bias": "horizon_ff_layer.output_layer.bias",
"horizon_ff_layer.residual_layer.weight": "horizon_ff_layer.residual_layer.weight",
"horizon_ff_layer.residual_layer.bias": "horizon_ff_layer.residual_layer.bias",
}
TRANSFORMER_LAYER_MAPPING = {
"stacked_transformer.layers[{i}].self_attn.qkv_proj.weight": "decoder.layers[{i}].self_attn.qkv_proj.weight",
"stacked_transformer.layers[{i}].self_attn.qkv_proj.bias": "decoder.layers[{i}].self_attn.qkv_proj.bias",
"stacked_transformer.layers[{i}].self_attn.o_proj.weight": "decoder.layers[{i}].self_attn.o_proj.weight",
"stacked_transformer.layers[{i}].self_attn.o_proj.bias": "decoder.layers[{i}].self_attn.o_proj.bias",
"stacked_transformer.layers[{i}].self_attn.scaling": "decoder.layers[{i}].self_attn.scaling",
"stacked_transformer.layers[{i}].mlp.gate_proj.weight": "decoder.layers[{i}].mlp.gate_proj.weight",
"stacked_transformer.layers[{i}].mlp.gate_proj.bias": "decoder.layers[{i}].mlp.gate_proj.bias",
"stacked_transformer.layers[{i}].mlp.down_proj.weight": "decoder.layers[{i}].mlp.down_proj.weight",
"stacked_transformer.layers[{i}].mlp.down_proj.bias": "decoder.layers[{i}].mlp.down_proj.bias",
"stacked_transformer.layers[{i}].mlp.layer_norm.weight": "decoder.layers[{i}].mlp.layer_norm.weight",
"stacked_transformer.layers[{i}].mlp.layer_norm.bias": "decoder.layers[{i}].mlp.layer_norm.bias",
"stacked_transformer.layers[{i}].input_layernorm.weight": "decoder.layers[{i}].input_layernorm.weight",
}
for old_key, new_key in MODEL_LAYER_MAPPING.items():
try:
old_attr = get_nested_attr(original_model, old_key) # Get tensor from original model
new_attr = get_nested_attr(timesfm_model, new_key) # Get corresponding attribute in new model
new_attr.data.copy_(old_attr.data) # Copy data
except AttributeError:
print(f"Skipping {old_key} (not found in original model).")
num_layers = len(timesfm_model.decoder.layers)
for i in range(num_layers):
for old_template, new_template in TRANSFORMER_LAYER_MAPPING.items():
old_key = old_template.format(i=i)
new_key = new_template.format(i=i)
try:
# Get tensor from original model
old_attr = get_nested_attr(original_model, old_key)
if "qkv_proj" in old_key:
# Split the tensor into q, k, v projections
q_proj, k_proj, v_proj = (
old_attr[: tfm.hparams.model_dims, ...],
old_attr[tfm.hparams.model_dims : tfm.hparams.model_dims * 2, ...],
old_attr[tfm.hparams.model_dims * 2 :, ...],
)
# Get corresponding attribute in new model
q_key = new_key.replace("qkv_proj", "q_proj")
q_attr = get_nested_attr(timesfm_model, q_key)
q_attr.data.copy_(q_proj.data) # Copy data
k_key = new_key.replace("qkv_proj", "k_proj")
k_attr = get_nested_attr(timesfm_model, k_key)
k_attr.data.copy_(k_proj.data) # Copy data
v_key = new_key.replace("qkv_proj", "v_proj")
v_attr = get_nested_attr(timesfm_model, v_key)
v_attr.data.copy_(v_proj.data) # Copy data
else:
# Get corresponding attribute in new model
new_attr = get_nested_attr(timesfm_model, new_key)
new_attr.data.copy_(old_attr.data) # Copy data
except AttributeError:
print(f"Skipping {old_key} (not found in original model).")
timesfm_model.save_pretrained(model_path, safe_serialization=safe_serialization)
shutil.rmtree(tmp_model_path)
def check_outputs(model_path, huggingface_repo_id):
"""Compares outputs between original and converted models."""
print("\nChecking model outputs...")
# Load original model
tfm = timesfm.TimesFm(
hparams=timesfm.TimesFmHparams(
backend="cuda" if torch.cuda.is_available() else "cpu",
per_core_batch_size=32,
horizon_len=128,
input_patch_len=32,
output_patch_len=128,
num_layers=50,
model_dims=1280,
use_positional_embedding=False,
point_forecast_mode="mean",
),
checkpoint=timesfm.TimesFmCheckpoint(huggingface_repo_id=huggingface_repo_id),
)
# Load converted model
converted_model = TimesFmModelForPrediction.from_pretrained(
model_path,
torch_dtype=torch.bfloat16,
attn_implementation="sdpa",
).to("cuda" if torch.cuda.is_available() else "cpu")
converted_model.eval() # Set to evaluation mode
# Create test inputs
forecast_input = [
np.sin(np.linspace(0, 20, 100)),
np.sin(np.linspace(0, 20, 200)),
np.sin(np.linspace(0, 20, 400)),
]
frequency_input = [0, 1, 2]
# Get predictions from original model
point_forecast_orig, quantile_forecast_orig = tfm.forecast(
forecast_input,
freq=frequency_input,
)
# Convert inputs to sequence of tensors
forecast_input_tensor = [
torch.tensor(ts, dtype=torch.bfloat16).to("cuda" if torch.cuda.is_available() else "cpu")
for ts in forecast_input
]
frequency_input_tensor = torch.tensor(frequency_input, dtype=torch.long).to(
"cuda" if torch.cuda.is_available() else "cpu"
)
# Get predictions from converted model
with torch.no_grad():
outputs = converted_model(past_values=forecast_input_tensor, freq=frequency_input_tensor, return_dict=True)
point_forecast_conv = outputs.mean_predictions.float().cpu().numpy()
quantile_forecast_conv = outputs.full_predictions.float().cpu().numpy()
# Compare outputs
point_forecast_diff = np.abs(point_forecast_orig - point_forecast_conv)
quantile_forecast_diff = np.abs(quantile_forecast_orig - quantile_forecast_conv)
max_point_diff = point_forecast_diff.max()
mean_point_diff = point_forecast_diff.mean()
max_quantile_diff = quantile_forecast_diff.max()
mean_quantile_diff = quantile_forecast_diff.mean()
print("\nOutput comparison:")
print(f"Point forecast - Max difference: {max_point_diff:.6f}")
print(f"Point forecast - Mean difference: {mean_point_diff:.6f}")
print(f"Quantile forecast - Max difference: {max_quantile_diff:.6f}")
print(f"Quantile forecast - Mean difference: {mean_quantile_diff:.6f}")
# Define acceptable thresholds
POINT_THRESHOLD = 1e-5
QUANTILE_THRESHOLD = 1e-5
if max_point_diff > POINT_THRESHOLD or max_quantile_diff > QUANTILE_THRESHOLD:
raise ValueError(
f"Output mismatch detected!\n"
f"Point forecast max diff: {max_point_diff} (threshold: {POINT_THRESHOLD})\n"
f"Quantile forecast max diff: {max_quantile_diff} (threshold: {QUANTILE_THRESHOLD})"
)
print("\n✓ All outputs match within acceptable tolerance!")
# Optional: Print shapes for verification
print("\nOutput shapes:")
print(f"Original point forecast: {point_forecast_orig.shape}")
print(f"Converted point forecast: {point_forecast_conv.shape}")
print(f"Original quantile forecast: {quantile_forecast_orig.shape}")
print(f"Converted quantile forecast: {quantile_forecast_conv.shape}")
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--output_dir",
required=True,
help="Location to write HF model and tokenizer",
)
parser.add_argument(
"--safe_serialization", type=bool, default=True, help="Whether or not to save using `safetensors`."
)
parser.add_argument(
"--huggingface_repo_id",
type=str,
default="google/timesfm-2.0-500m-pytorch",
help="The Hugging Face repository ID to use for the model.",
)
args = parser.parse_args()
# if the saved model file exists, skip the conversion
if os.path.exists(
os.path.join(args.output_dir, "model.safetensors" if args.safe_serialization else "pytorch_model.bin")
):
print(f"Model already exists in {args.output_dir}, skipping conversion.")
else:
write_model(
model_path=args.output_dir,
safe_serialization=args.safe_serialization,
huggingface_repo_id=args.huggingface_repo_id,
)
check_outputs(args.output_dir, args.huggingface_repo_id)
if __name__ == "__main__":
main()
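For intuition, the `qkv_proj` handling in `write_model` assumes the original checkpoint stores the query, key and value projections fused along the output dimension; a standalone sketch of that split (hypothetical shapes, not the converter itself):

```python
import torch

model_dims = 1280  # matches tfm.hparams.model_dims above
qkv_weight = torch.randn(3 * model_dims, model_dims)  # fused [q; k; v] projection weight

# Same slicing as in write_model(): first third -> q, second third -> k, last third -> v
q_proj = qkv_weight[:model_dims, ...]
k_proj = qkv_weight[model_dims : 2 * model_dims, ...]
v_proj = qkv_weight[2 * model_dims :, ...]
assert q_proj.shape == k_proj.shape == v_proj.shape == (model_dims, model_dims)
```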

View File

@@ -0,0 +1,904 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/timesfm/modular_timesfm.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_timesfm.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2025 Google LLC and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from dataclasses import dataclass
from typing import Callable, Optional, Sequence, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...integrations import use_kernel_forward_from_hub
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
logging,
replace_return_docstrings,
)
from .configuration_timesfm import TimesFmConfig
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "google/timesfm-2.0-500m-pytorch"
_CONFIG_FOR_DOC = "TimesFmConfig"
@dataclass
class TimesFmOutput(BaseModelOutput):
"""
Args:
loc (`torch.Tensor` of shape `(batch_size, )`):
The mean of the time series inputs.
scale (`torch.Tensor` of shape `(batch_size,)`):
The scale of the time series inputs.
"""
loc: Optional[torch.Tensor] = None
scale: Optional[torch.Tensor] = None
@dataclass
class TimesFmOutputForPrediction(BaseModelOutput):
"""
Args:
mean_predictions (`torch.Tensor` of shape `(batch_size, sequence_length)`):
The mean predictions of the time series.
full_predictions (`torch.Tensor` of shape `(batch_size, sequence_length)`):
The full predictions of the time series including the mean and the quantiles.
loss (`torch.Tensor` of shape `(1,)`, *optional*, returned when `future_values` is provided):
The loss of the TimesFM model.
"""
mean_predictions: Optional[torch.Tensor] = None
full_predictions: Optional[torch.Tensor] = None
loss: Optional[Union[torch.Tensor, float]] = None
class TimesFmMLP(nn.Module):
"""Pax MLP in pytorch."""
def __init__(self, config: TimesFmConfig):
super().__init__()
hidden_size = config.hidden_size
intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(hidden_size, intermediate_size)
self.down_proj = nn.Linear(intermediate_size, hidden_size)
self.layer_norm = nn.LayerNorm(normalized_shape=hidden_size, eps=1e-6)
def forward(self, x, paddings=None):
gate_inp = self.layer_norm(x)
gate = self.gate_proj(gate_inp)
gate = F.relu(gate)
outputs = self.down_proj(gate)
if paddings is not None:
outputs = outputs * (1.0 - paddings[:, :, None])
return outputs + x
class TimesFmResidualBlock(nn.Module):
"""TimesFM residual block."""
def __init__(self, input_dims, hidden_dims, output_dims):
super().__init__()
self.input_dims = input_dims
self.hidden_dims = hidden_dims
self.output_dims = output_dims
self.input_layer = nn.Linear(input_dims, hidden_dims)
self.activation = nn.SiLU()
self.output_layer = nn.Linear(hidden_dims, output_dims)
self.residual_layer = nn.Linear(input_dims, output_dims)
def forward(self, x):
hidden = self.input_layer(x)
hidden = self.activation(hidden)
output = self.output_layer(hidden)
residual = self.residual_layer(x)
return output + residual
@use_kernel_forward_from_hub("RMSNorm")
class TimesFmRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
TimesFmRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
class TimesFmPositionalEmbedding(nn.Module):
"""Generates position embedding for a given 1-d sequence."""
def __init__(self, config: TimesFmConfig):
super().__init__()
min_timescale = config.min_timescale
max_timescale = config.max_timescale
self.embedding_dims = config.hidden_size
num_timescales = self.embedding_dims // 2
log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / max(num_timescales - 1, 1)
self.register_buffer(
"inv_timescales",
min_timescale * torch.exp(torch.arange(num_timescales, dtype=torch.float32) * -log_timescale_increment),
)
def forward(self, seq_length=None, position=None):
"""Generates a Tensor of sinusoids with different frequencies.
Args:
seq_length: an optional Python int defining the output sequence length.
Not required if the `position` argument is specified.
position: [B, seq_length], optional position for each token in the
sequence, only required when the sequence is packed.
Returns:
[B, seqlen, D] if `position` is specified, else [1, seqlen, D]
"""
if position is None and seq_length is None:
raise ValueError("Either position or seq_length must be provided")
if position is None:
# [1, seqlen]
position = torch.arange(seq_length, dtype=torch.float32, device=self.inv_timescales.device).unsqueeze(0)
elif position.ndim != 2:
raise ValueError(f"position must be 2-dimensional, got shape {position.shape}")
scaled_time = position.view(*position.shape, 1) * self.inv_timescales.view(1, 1, -1)
signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)
# Padding to ensure correct embedding dimension
signal = F.pad(signal, (0, 0, 0, self.embedding_dims % 2))
return signal
def simple_eager_attention_forward(
module: nn.Module,
query_states: torch.Tensor,
key_states: torch.Tensor,
value_states: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs,
):
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
class TimesFmAttention(nn.Module):
"""Implements the attention used in TimesFM. One key difference is that there is _per_dim_scaling of the query."""
def __init__(self, config: TimesFmConfig, layer_idx: int):
super().__init__()
self.config = config
self.is_causal = True
self.attention_dropout = config.attention_dropout
self.layer_idx = layer_idx
self.num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.head_dim = config.head_dim
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_heads * self.head_dim
self.scaling = nn.Parameter(torch.empty((self.head_dim,)))
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim)
self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim)
self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size)
def _scale_query(self, query: torch.Tensor) -> torch.Tensor:
scale = F.softplus(self.scaling).mul(1.442695041 / math.sqrt(self.head_dim))
return query * scale[None, None, None, :]
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
query_states = self._scale_query(query_states)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
attention_interface: Callable = simple_eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
logger.warning_once(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=1.0,
**kwargs,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
class TimesFmDecoderLayer(nn.Module):
"""Transformer layer."""
def __init__(self, config: TimesFmConfig, layer_idx: int):
super().__init__()
self.self_attn = TimesFmAttention(config, layer_idx=layer_idx)
self.mlp = TimesFmMLP(config)
self.input_layernorm = TimesFmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
paddings: torch.Tensor,
output_attentions: bool = False,
) -> tuple[Optional[torch.Tensor], torch.Tensor]:
# Self Attention
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states, scores = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
)
hidden_states = residual + hidden_states
# MLP
hidden_states = self.mlp(hidden_states, paddings=paddings)
return scores, hidden_states
TIMESFM_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`TimesFmConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
@add_start_docstrings(
"The bare TimesFM Model outputting raw hidden-states without any specific head on top.",
TIMESFM_START_DOCSTRING,
)
class TimesFmPreTrainedModel(PreTrainedModel):
"""handles the loading for all models."""
config_class = TimesFmConfig
base_model_prefix = "timesfm"
_no_split_modules = ["TimesFmDecoderLayer"]
main_input_name = "past_values"
_supports_sdpa = True
def _init_weights(self, module):
if isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0, std=self.config.initializer_range)
elif isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0, std=self.config.initializer_range)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.LayerNorm):
nn.init.ones_(module.weight)
nn.init.zeros_(module.bias)
elif isinstance(module, TimesFmRMSNorm):
nn.init.zeros_(module.weight)
elif isinstance(module, TimesFmAttention):
# Initialize scaling parameter
nn.init.ones_(module.scaling)
TIMESFM_INPUTS_DOCSTRING = r"""
Args:
past_values: list of time series forecast contexts. Each context time series
is a torch Tensor, potentially of a different context length.
freq: frequency of each context time series in the inputs. 0 for high frequency
(default), 1 for medium, and 2 for low.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
"""
@add_start_docstrings(
"The bare TimesFM Model outputting raw hidden-states without any specific head on top.",
TIMESFM_START_DOCSTRING,
)
class TimesFmModel(TimesFmPreTrainedModel):
"""Patched time-series decoder without any specific output layer."""
def __init__(self, config: TimesFmConfig):
super().__init__(config)
self.config = config
self.input_ff_layer = TimesFmResidualBlock(
input_dims=2 * config.patch_length,
output_dims=config.hidden_size,
hidden_dims=config.intermediate_size,
)
self.freq_emb = nn.Embedding(num_embeddings=config.freq_size, embedding_dim=config.hidden_size)
self.layers = nn.ModuleList(
[TimesFmDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
if self.config.use_positional_embedding:
self.position_emb = TimesFmPositionalEmbedding(config=config)
# Initialize weights and apply final processing
self.post_init()
def _forward_transform(
self, inputs: torch.Tensor, patched_pads: torch.Tensor
) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
"""Input is of shape [B, N, P]."""
mu, sigma = self._timesfm_masked_mean_std(inputs, patched_pads)
sigma = torch.where(
sigma < self.config.tolerance,
torch.tensor(1.0, dtype=sigma.dtype, device=sigma.device),
sigma,
)
# Normalize each patch
outputs = (inputs - mu[:, None, None]) / sigma[:, None, None]
outputs = torch.where(
torch.abs(inputs - self.config.pad_val) < self.config.tolerance,
torch.tensor(self.config.pad_val, dtype=outputs.dtype, device=outputs.device),
outputs,
)
return outputs, (mu, sigma)
@can_return_tuple
@add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING)
def forward(
self,
past_values: torch.Tensor,
past_values_padding: torch.LongTensor,
freq: torch.Tensor,
output_attentions: bool = False,
output_hidden_states: bool = False,
) -> TimesFmOutput:
"""
past_values_padding (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
The padding indicator of the time series.
"""
# Reshape into patches (using view for efficiency)
bsize = past_values.shape[0]
patched_inputs = past_values.view(bsize, -1, self.config.patch_length)
patched_pads = past_values_padding.view(bsize, -1, self.config.patch_length)
patched_inputs = torch.where(
torch.abs(patched_pads - 1.0) < self.config.tolerance,
torch.tensor(0.0, dtype=patched_inputs.dtype, device=patched_inputs.device),
patched_inputs,
)
patched_pads = torch.where(
torch.abs(patched_inputs - self.config.pad_val) < self.config.tolerance,
torch.tensor(1.0, dtype=patched_pads.dtype, device=patched_pads.device),
patched_pads,
)
patched_inputs, stats = self._forward_transform(patched_inputs, patched_pads)
# B x N x D
patched_inputs = patched_inputs * (1.0 - patched_pads)
concat_inputs = torch.cat([patched_inputs, patched_pads], dim=-1)
model_input = self.input_ff_layer(concat_inputs)
# A patch should not be padded even if there is at least one zero.
patched_padding = torch.min(patched_pads, dim=-1)[0] # Get the values from the min result
if self.config.use_positional_embedding:
pos_emb = self.position_emb(model_input.shape[1])
pos_emb = torch.concat([pos_emb] * model_input.shape[0], dim=0)
pos_emb = self._timesfm_shift_padded_seq(patched_padding, pos_emb)
model_input += pos_emb
f_emb = self.freq_emb(freq) # B x 1 x D
model_input += f_emb
# Convert paddings to attention mask and combine with causal mask
hidden_states = model_input
attention_mask = self._prepare_4d_attention_mask(
attention_mask=patched_padding,
sequence_length=hidden_states.shape[1],
dtype=hidden_states.dtype,
device=hidden_states.device,
is_causal=True,
)
all_attentions = []
all_hidden_states = []
for layer in self.layers[: self.config.num_hidden_layers]:
scores, hidden_states = layer(
hidden_states=hidden_states,
attention_mask=attention_mask,
paddings=patched_padding,
output_attentions=output_attentions,
)
if output_attentions:
all_attentions.append(scores)
if output_hidden_states:
all_hidden_states.append(hidden_states)
if output_hidden_states:
all_hidden_states = [model_input] + all_hidden_states
else:
all_hidden_states = None
return TimesFmOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_attentions if output_attentions else None,
loc=stats[0],
scale=stats[1],
)
@staticmethod
def _prepare_4d_attention_mask(
attention_mask: Optional[torch.Tensor],
sequence_length: int,
dtype: torch.dtype,
device: torch.device,
is_causal: bool = True,
) -> Optional[torch.Tensor]:
"""
Creates 4D attention mask and combines causal and padding masks if needed.
Args:
attention_mask: Optional tensor of shape (batch_size, seq_length) containing padding mask
sequence_length: Length of the sequence
dtype: Data type of the mask
device: Device of the mask
is_causal: Whether to apply causal masking
Returns:
4D attention mask of shape (batch_size, 1, seq_length, seq_length)
"""
# Get minimum value for the dtype
min_value = torch.finfo(dtype).min if dtype.is_floating_point else torch.iinfo(dtype).min
# Handle padding mask
if attention_mask is not None:
# Convert 2D padding mask to 4D attention mask
attention_mask = attention_mask.view(attention_mask.shape[0], 1, 1, -1)
attention_mask = attention_mask * min_value
# Create causal mask if needed
if is_causal:
causal_mask = torch.triu(
torch.ones((sequence_length, sequence_length), dtype=dtype, device=device) * min_value,
diagonal=1,
)
causal_mask = causal_mask.view(1, 1, sequence_length, sequence_length)
# Combine with padding mask if it exists
if attention_mask is not None:
attention_mask = torch.minimum(attention_mask, causal_mask)
else:
attention_mask = causal_mask
return attention_mask
@staticmethod
def _timesfm_masked_mean_std(inputs: torch.Tensor, padding: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Calculates mean and standard deviation of `inputs` across axis 1.
It excludes values where `padding` is 1.
Args:
inputs: A PyTorch tensor of shape [b, n, p].
padding: A PyTorch tensor of shape [b, n, p] with values 0 or 1.
Returns:
A tuple containing the mean and standard deviation.
We return the statistics of the first patch with more than three non-padded values.
"""
# Selecting the first patch with more than 3 unpadded values.
def _get_patch_index(arr: torch.Tensor):
indices = torch.argmax((arr >= 3).to(torch.int32), dim=1)
row_sum = (arr >= 3).to(torch.int32).sum(dim=1)
return torch.where(row_sum == 0, arr.shape[1] - 1, indices)
pad_sum = torch.sum(1 - padding, dim=2)
patch_indices = _get_patch_index(pad_sum)
bidxs = torch.arange(inputs.shape[0])
arr = inputs[bidxs, patch_indices, :]
pad = padding[bidxs, patch_indices, :]
# Create a mask where padding is 0
mask = 1 - pad
# Calculate the number of valid elements
num_valid_elements = torch.sum(mask, dim=1)
num_valid_elements = torch.where(
num_valid_elements == 0,
torch.tensor(1, dtype=num_valid_elements.dtype, device=num_valid_elements.device),
num_valid_elements,
)
# Calculate the masked sum and squared sum
masked_sum = torch.sum(arr * mask, dim=1)
masked_squared_sum = torch.sum((arr * mask) ** 2, dim=1)
# Calculate the masked mean and standard deviation
masked_mean = masked_sum / num_valid_elements
masked_var = masked_squared_sum / num_valid_elements - masked_mean**2
masked_var = torch.where(
masked_var < 0.0,
torch.tensor(0.0, dtype=masked_var.dtype, device=masked_var.device),
masked_var,
)
masked_std = torch.sqrt(masked_var)
return masked_mean, masked_std
@staticmethod
def _timesfm_shift_padded_seq(mask: torch.Tensor, seq: torch.Tensor) -> torch.Tensor:
"""Shifts rows of seq based on the first 0 in each row of the mask.
Args:
mask: mask tensor of shape [B, N]
seq: seq tensor of shape [B, N, P]
Returns:
The shifted sequence.
"""
batch_size, num_seq, feature_dim = seq.shape
new_mask: torch.BoolTensor = mask == 0
# Use argmax to find the first True value in each row
indices = new_mask.to(torch.int32).argmax(dim=1)
# Handle rows with all zeros
indices[~new_mask.any(dim=1)] = -1
# Create index ranges for each sequence in the batch
idx_range = torch.arange(num_seq, device=seq.device).view(1, -1, 1).expand(batch_size, -1, feature_dim)
# Calculate shifted indices for each element in each sequence
shifted_idx = (idx_range - indices[:, None, None]) % num_seq
# Gather values from seq using shifted indices
shifted_seq = seq.gather(1, shifted_idx)
return shifted_seq
class TimesFmModelForPrediction(TimesFmPreTrainedModel):
"""TimesFM model for quantile and mean prediction."""
def __init__(self, config: TimesFmConfig):
super().__init__(config)
self.config = config
self.context_len = config.context_length
self.horizon_len = config.horizon_length
self.decoder = TimesFmModel(config)
# quantile and mean output
self.horizon_ff_layer = TimesFmResidualBlock(
input_dims=config.hidden_size,
output_dims=config.horizon_length * (1 + len(config.quantiles)),
hidden_dims=config.intermediate_size,
)
# Initialize weights and apply final processing
self.post_init()
def _preprocess(
self, inputs: Sequence[torch.Tensor], freq: Sequence[int]
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Formats and pads raw inputs to feed into the model.
This function both pads each time series to match the context length, and
pads the inputs to meet the SPMD shape requirement.
Args:
inputs: A list of 1d Tensors. Each Tensor is the context time series of
a single forecast task.
freq: list of frequencies
Returns:
A tuple of:
- the padded input time series to meet the model required context.
- the padding indicator.
- the frequency of each input time series, as an int32 tensor of shape `(batch_size, 1)`.
"""
input_ts, input_padding, inp_freq = [], [], []
for i, ts in enumerate(inputs):
input_len = ts.shape[0]
padding = torch.zeros(input_len + self.horizon_len, dtype=ts.dtype, device=ts.device)
if input_len < self.context_len:
num_front_pad = self.context_len - input_len
ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0)
padding = torch.cat([torch.ones(num_front_pad, dtype=ts.dtype, device=padding.device), padding], dim=0)
elif input_len > self.context_len:
ts = ts[-self.context_len :]
padding = padding[-(self.context_len + self.horizon_len) :]
input_ts.append(ts)
input_padding.append(padding)
inp_freq.append(freq[i])
return (
torch.stack(input_ts, dim=0),
torch.stack(input_padding, dim=0),
torch.tensor(inp_freq, dtype=torch.int32).reshape(-1, 1),
)
def _postprocess_output(
self, model_output: torch.Tensor, stats: tuple[torch.Tensor, torch.Tensor]
) -> torch.Tensor:
"""Postprocess output of stacked transformer."""
# B x N x (H.Q)
output_ts = self.horizon_ff_layer(model_output)
# Reshape using view
b, n, _ = output_ts.shape
output_ts = output_ts.view(b, n, self.config.horizon_length, len(self.config.quantiles) + 1)
mu, sigma = stats
return output_ts * sigma[:, None, None, None] + mu[:, None, None, None]
def _quantile_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
losses = []
for i, q in enumerate(self.config.quantiles):
errors = targets - predictions[..., i]
loss = torch.max((q - 1) * errors, q * errors)
losses.append(loss.mean())
return torch.stack(losses).mean()
@can_return_tuple
@add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TimesFmOutputForPrediction, config_class=_CONFIG_FOR_DOC)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TimesFmOutputForPrediction,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
past_values: Sequence[torch.Tensor],
freq: Optional[Sequence[Union[torch.Tensor, int]]] = None,
window_size: Optional[int] = None,
future_values: Optional[torch.Tensor] = None,
forecast_context_len: Optional[int] = None,
return_forecast_on_context: bool = False,
truncate_negative: bool = False,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) -> TimesFmOutputForPrediction:
r"""
window_size (`int`, *optional*):
Window size of trend + residual decomposition. If None then we do not do decomposition.
future_values (`torch.Tensor`, *optional*):
Optional future time series values to be used for loss computation.
forecast_context_len (`int`, *optional*):
Optional max context length.
return_forecast_on_context (`bool`, *optional*):
True to return the forecast on the context when available, i.e. after the first input patch.
truncate_negative (`bool`, *optional*):
Truncate to only non-negative values if any of the contexts have non-negative values,
otherwise do nothing.
output_attentions (`bool`, *optional*):
Whether to output the attentions.
output_hidden_states (`bool`, *optional*):
Whether to output the hidden states.
Returns:
A TimesFmOutputForPrediction object or a tuple containing:
- the mean forecast of size (# past_values, # forecast horizon),
- the full forecast (mean + quantiles) of size
(# past_values, # forecast horizon, 1 + # quantiles).
- loss: the mean squared error loss + quantile loss if `future_values` is provided.
"""
if forecast_context_len is None:
fcontext_len = self.context_len
else:
fcontext_len = forecast_context_len
# Get device from first input tensor
device = past_values[0].device
# Truncate inputs to forecast_context_len
inputs = [ts[-fcontext_len:] for ts in past_values]
inp_min = torch.min(torch.stack([torch.min(ts) for ts in inputs]))
if window_size is not None:
new_inputs = []
new_freqs = []
for i, ts in enumerate(inputs):
new_inputs.extend(self._timesfm_moving_average(ts, window_size))
if freq is not None:
new_freqs.extend([freq[i]] * 2)
inputs = new_inputs
if freq is not None:
freq = new_freqs
if freq is None:
logger.info("No frequency provided via `freq`. Default to high (0).")
freq = [0] * len(inputs)
if output_attentions is None:
output_attentions = self.config.output_attentions
if output_hidden_states is None:
output_hidden_states = self.config.output_hidden_states
input_ts, input_padding, inp_freq = self._preprocess(inputs, freq)
# Move tensors to the same device as input
input_ts = input_ts.to(device)
input_padding = input_padding.to(device)
inp_freq = inp_freq.to(device)
final_out = input_ts
context_len = final_out.shape[1]
full_outputs = []
if input_padding.shape[1] != final_out.shape[1] + self.horizon_len:
raise ValueError(
"Length of paddings must match length of input + horizon_len:"
f" {input_padding.shape[1]} != {final_out.shape[1]} + {self.horizon_len}"
)
output_patch_len = self.config.horizon_length
num_decode_patches = (self.horizon_len + output_patch_len - 1) // output_patch_len
for step_index in range(num_decode_patches):
current_padding = input_padding[:, 0 : final_out.shape[1]]
input_ts = final_out[:, -fcontext_len:]
input_padding = current_padding[:, -fcontext_len:]
decoder_output = self.decoder(
past_values=input_ts,
past_values_padding=input_padding,
freq=inp_freq,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
fprop_outputs = self._postprocess_output(
decoder_output.last_hidden_state,
(decoder_output.loc, decoder_output.scale),
)
if return_forecast_on_context and step_index == 0:
# For the first decoding step, collect the model forecast on the context,
# excluding the first input patch, for which no forecast is available.
new_full_ts = fprop_outputs[:, :-1, : self.config.patch_length, :]
# We have to use reshape and not view for non-contiguous memory
new_full_ts = new_full_ts.reshape(new_full_ts.size(0), -1, new_full_ts.size(3))
full_outputs.append(new_full_ts)
# (full batch, last patch, output_patch_len, index of mean forecast = 0)
new_ts = fprop_outputs[:, -1, :output_patch_len, 0]
new_full_ts = fprop_outputs[:, -1, :output_patch_len, :]
# (full batch, last patch, output_patch_len, all output indices)
full_outputs.append(new_full_ts)
final_out = torch.concatenate([final_out, new_ts], axis=-1)
if return_forecast_on_context:
# `full_outputs` indexing starts at after the first input patch.
full_outputs = torch.concatenate(full_outputs, axis=1)[
:, : (context_len - self.config.patch_length + self.horizon_len), :
]
else:
# `full_outputs` indexing starts at the forecast horizon.
full_outputs = torch.concatenate(full_outputs, axis=1)[:, 0 : self.horizon_len, :]
mean_outputs = full_outputs[:, :, 0]
if window_size is not None:
mean_outputs = mean_outputs[0::2, ...] + mean_outputs[1::2, ...]
full_outputs = full_outputs[0::2, ...] + full_outputs[1::2, ...]
if inp_min >= 0 and truncate_negative:
mean_outputs = torch.maximum(mean_outputs, torch.zeros_like(mean_outputs))
full_outputs = torch.maximum(full_outputs, torch.zeros_like(full_outputs))
loss = None
if future_values is not None:
mse_loss = F.mse_loss(mean_outputs, future_values)
quantile_loss = self._quantile_loss(full_outputs[:, :, 1:], future_values)
loss = mse_loss + quantile_loss
return TimesFmOutputForPrediction(
last_hidden_state=decoder_output.last_hidden_state,
attentions=decoder_output.attentions if output_attentions else None,
hidden_states=decoder_output.hidden_states if output_hidden_states else None,
mean_predictions=mean_outputs,
full_predictions=full_outputs,
loss=loss,
)
@staticmethod
def _timesfm_moving_average(arr: torch.Tensor, window_size: int) -> list[torch.Tensor]:
"""Calculates the moving average using PyTorch's convolution function."""
# Pad with zeros to handle initial window positions
arr_padded = F.pad(arr, (window_size - 1, 0), "constant", 0)
# Create a convolution kernel
kernel = torch.ones(window_size, dtype=arr.dtype, device=arr.device) / window_size
# Apply convolution to calculate the moving average
smoothed_arr = F.conv1d(arr_padded.view(1, 1, -1), kernel.view(1, 1, -1)).squeeze()
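# Return the trend (moving average) and the residual. The two components sum back to
# the original series, which is why `window_size` decomposition doubles the batch and
# the trend/residual forecasts are added back together after decoding.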
return [smoothed_arr, arr - smoothed_arr]
__all__ = ["TimesFmModelForPrediction", "TimesFmPreTrainedModel", "TimesFmModel"]

View File

@ -0,0 +1,860 @@
# coding=utf-8
# Copyright 2025 Google LLC and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch TimesFM model."""
import math
from dataclasses import dataclass
from typing import Callable, Optional, Sequence, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
logging,
replace_return_docstrings,
)
from ..llama.modeling_llama import LlamaRMSNorm
from ..phi4_multimodal.modeling_phi4_multimodal import simple_eager_attention_forward
from .configuration_timesfm import TimesFmConfig
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "google/timesfm-2.0-500m-pytorch"
_CONFIG_FOR_DOC = "TimesFmConfig"
@dataclass
class TimesFmOutput(BaseModelOutput):
"""
Args:
loc (`torch.Tensor` of shape `(batch_size, )`):
The mean of the time series inputs.
scale (`torch.Tensor` of shape `(batch_size,)`):
The scale of the time series inputs.
"""
loc: Optional[torch.Tensor] = None
scale: Optional[torch.Tensor] = None
@dataclass
class TimesFmOutputForPrediction(BaseModelOutput):
"""
Args:
mean_predictions (`torch.Tensor` of shape `(batch_size, sequence_length)`):
The mean predictions of the time series.
full_predictions (`torch.Tensor` of shape `(batch_size, sequence_length, 1 + len(config.quantiles))`):
The full predictions of the time series, including the mean and the quantiles.
loss (`torch.Tensor` of shape `(1,)`, *optional*, returned when `future_values` is provided):
The loss of the TimesFM model.
"""
mean_predictions: Optional[torch.Tensor] = None
full_predictions: Optional[torch.Tensor] = None
loss: Optional[Union[torch.Tensor, float]] = None
class TimesFmMLP(nn.Module):
"""Pax MLP in pytorch."""
def __init__(self, config: TimesFmConfig):
super().__init__()
hidden_size = config.hidden_size
intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(hidden_size, intermediate_size)
self.down_proj = nn.Linear(intermediate_size, hidden_size)
self.layer_norm = nn.LayerNorm(normalized_shape=hidden_size, eps=1e-6)
def forward(self, x, paddings=None):
gate_inp = self.layer_norm(x)
gate = self.gate_proj(gate_inp)
gate = F.relu(gate)
outputs = self.down_proj(gate)
if paddings is not None:
outputs = outputs * (1.0 - paddings[:, :, None])
return outputs + x
class TimesFmResidualBlock(nn.Module):
"""TimesFM residual block."""
def __init__(self, input_dims, hidden_dims, output_dims):
super().__init__()
self.input_dims = input_dims
self.hidden_dims = hidden_dims
self.output_dims = output_dims
self.input_layer = nn.Linear(input_dims, hidden_dims)
self.activation = nn.SiLU()
self.output_layer = nn.Linear(hidden_dims, output_dims)
self.residual_layer = nn.Linear(input_dims, output_dims)
def forward(self, x):
hidden = self.input_layer(x)
hidden = self.activation(hidden)
output = self.output_layer(hidden)
residual = self.residual_layer(x)
return output + residual
class TimesFmRMSNorm(LlamaRMSNorm):
pass
class TimesFmPositionalEmbedding(nn.Module):
"""Generates position embedding for a given 1-d sequence."""
def __init__(self, config: TimesFmConfig):
super().__init__()
min_timescale = config.min_timescale
max_timescale = config.max_timescale
self.embedding_dims = config.hidden_size
num_timescales = self.embedding_dims // 2
log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / max(num_timescales - 1, 1)
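# Geometrically spaced inverse timescales, following the standard transformer
# sinusoidal position-embedding schedule between `min_timescale` and `max_timescale`.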
self.register_buffer(
"inv_timescales",
min_timescale * torch.exp(torch.arange(num_timescales, dtype=torch.float32) * -log_timescale_increment),
)
def forward(self, seq_length=None, position=None):
"""Generates a Tensor of sinusoids with different frequencies.
Args:
seq_length: an optional Python int defining the output sequence length.
Not required if the `position` argument is specified.
position: [B, seq_length], optional position for each token in the
sequence, only required when the sequence is packed.
Returns:
[B, seqlen, D] if `position` is specified, else [1, seqlen, D]
"""
if position is None and seq_length is None:
raise ValueError("Either position or seq_length must be provided")
if position is None:
# [1, seqlen]
position = torch.arange(seq_length, dtype=torch.float32, device=self.inv_timescales.device).unsqueeze(0)
elif position.ndim != 2:
raise ValueError(f"position must be 2-dimensional, got shape {position.shape}")
scaled_time = position.view(*position.shape, 1) * self.inv_timescales.view(1, 1, -1)
signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)
# Padding to ensure correct embedding dimension
signal = F.pad(signal, (0, 0, 0, self.embedding_dims % 2))
return signal
class TimesFmAttention(nn.Module):
"""Implements the attention used in TimesFM. One key difference is that there is _per_dim_scaling of the query."""
def __init__(self, config: TimesFmConfig, layer_idx: int):
super().__init__()
self.config = config
self.is_causal = True
self.attention_dropout = config.attention_dropout
self.layer_idx = layer_idx
self.num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.head_dim = config.head_dim
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_heads * self.head_dim
self.scaling = nn.Parameter(torch.empty((self.head_dim,)))
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim)
self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim)
self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size)
def _scale_query(self, query: torch.Tensor) -> torch.Tensor:
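# Learned per-dim query scaling: softplus keeps the scale positive, and
# 1.442695041 = 1 / softplus(0) = 1 / ln(2), so a zero-valued `scaling` parameter
# reduces to the usual 1 / sqrt(head_dim) attention scaling.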
scale = F.softplus(self.scaling).mul(1.442695041 / math.sqrt(self.head_dim))
return query * scale[None, None, None, :]
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
query_states = self._scale_query(query_states)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
attention_interface: Callable = simple_eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
logger.warning_once(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=1.0,
**kwargs,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
class TimesFmDecoderLayer(nn.Module):
"""Transformer layer."""
def __init__(self, config: TimesFmConfig, layer_idx: int):
super().__init__()
self.self_attn = TimesFmAttention(config, layer_idx=layer_idx)
self.mlp = TimesFmMLP(config)
self.input_layernorm = TimesFmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
paddings: torch.Tensor,
output_attentions: bool = False,
) -> tuple[Optional[torch.Tensor], torch.Tensor]:
# Self Attention
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states, scores = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
)
hidden_states = residual + hidden_states
# MLP
hidden_states = self.mlp(hidden_states, paddings=paddings)
return scores, hidden_states
TIMESFM_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`TimesFmConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
@add_start_docstrings(
"The bare TimesFM Model outputting raw hidden-states without any specific head on top.",
TIMESFM_START_DOCSTRING,
)
class TimesFmPreTrainedModel(PreTrainedModel):
"""handles the loading for all models."""
config_class = TimesFmConfig
base_model_prefix = "timesfm"
_no_split_modules = ["TimesFmDecoderLayer"]
main_input_name = "past_values"
_supports_sdpa = True
def _init_weights(self, module):
if isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0, std=self.config.initializer_range)
elif isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0, std=self.config.initializer_range)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.LayerNorm):
nn.init.ones_(module.weight)
nn.init.zeros_(module.bias)
elif isinstance(module, TimesFmRMSNorm):
nn.init.zeros_(module.weight)
elif isinstance(module, TimesFmAttention):
# Initialize scaling parameter
nn.init.ones_(module.scaling)
TIMESFM_INPUTS_DOCSTRING = r"""
Args:
past_values: list of time series forecast contexts. Each context time series is a
torch Tensor, and the individual series may have different context lengths.
freq: frequency of each context time series in the inputs. 0 for high frequency
(default), 1 for medium, and 2 for low.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
"""
@add_start_docstrings(
"The bare TimesFM Model outputting raw hidden-states without any specific head on top.",
TIMESFM_START_DOCSTRING,
)
class TimesFmModel(TimesFmPreTrainedModel):
"""Patched time-series decoder without any specific output layer."""
def __init__(self, config: TimesFmConfig):
super().__init__(config)
self.config = config
self.input_ff_layer = TimesFmResidualBlock(
input_dims=2 * config.patch_length,
output_dims=config.hidden_size,
hidden_dims=config.intermediate_size,
)
self.freq_emb = nn.Embedding(num_embeddings=config.freq_size, embedding_dim=config.hidden_size)
self.layers = nn.ModuleList(
[TimesFmDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
if self.config.use_positional_embedding:
self.position_emb = TimesFmPositionalEmbedding(config=config)
# Initialize weights and apply final processing
self.post_init()
def _forward_transform(
self, inputs: torch.Tensor, patched_pads: torch.Tensor
) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
"""Input is of shape [B, N, P]."""
mu, sigma = self._timesfm_masked_mean_std(inputs, patched_pads)
sigma = torch.where(
sigma < self.config.tolerance,
torch.tensor(1.0, dtype=sigma.dtype, device=sigma.device),
sigma,
)
# Normalize each patch
outputs = (inputs - mu[:, None, None]) / sigma[:, None, None]
outputs = torch.where(
torch.abs(inputs - self.config.pad_val) < self.config.tolerance,
torch.tensor(self.config.pad_val, dtype=outputs.dtype, device=outputs.device),
outputs,
)
return outputs, (mu, sigma)
@can_return_tuple
@add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING)
def forward(
self,
past_values: torch.Tensor,
past_values_padding: torch.LongTensor,
freq: torch.Tensor,
output_attentions: bool = False,
output_hidden_states: bool = False,
) -> TimesFmOutput:
"""
past_values_padding (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
The padding indicator of the time series.
"""
# Reshape into patches (using view for efficiency)
bsize = past_values.shape[0]
patched_inputs = past_values.view(bsize, -1, self.config.patch_length)
patched_pads = past_values_padding.view(bsize, -1, self.config.patch_length)
patched_inputs = torch.where(
torch.abs(patched_pads - 1.0) < self.config.tolerance,
torch.tensor(0.0, dtype=patched_inputs.dtype, device=patched_inputs.device),
patched_inputs,
)
patched_pads = torch.where(
torch.abs(patched_inputs - self.config.pad_val) < self.config.tolerance,
torch.tensor(1.0, dtype=patched_pads.dtype, device=patched_pads.device),
patched_pads,
)
patched_inputs, stats = self._forward_transform(patched_inputs, patched_pads)
# B x N x D
patched_inputs = patched_inputs * (1.0 - patched_pads)
concat_inputs = torch.cat([patched_inputs, patched_pads], dim=-1)
model_input = self.input_ff_layer(concat_inputs)
# A patch is treated as non-padded if it contains at least one non-padded value.
patched_padding = torch.min(patched_pads, dim=-1)[0] # Get the values from the min result
if self.config.use_positional_embedding:
pos_emb = self.position_emb(model_input.shape[1])
pos_emb = torch.concat([pos_emb] * model_input.shape[0], dim=0)
pos_emb = self._timesfm_shift_padded_seq(patched_padding, pos_emb)
model_input += pos_emb
f_emb = self.freq_emb(freq) # B x 1 x D
model_input += f_emb
# Convert paddings to attention mask and combine with causal mask
hidden_states = model_input
attention_mask = self._prepare_4d_attention_mask(
attention_mask=patched_padding,
sequence_length=hidden_states.shape[1],
dtype=hidden_states.dtype,
device=hidden_states.device,
is_causal=True,
)
all_attentions = []
all_hidden_states = []
for layer in self.layers[: self.config.num_hidden_layers]:
scores, hidden_states = layer(
hidden_states=hidden_states,
attention_mask=attention_mask,
paddings=patched_padding,
output_attentions=output_attentions,
)
if output_attentions:
all_attentions.append(scores)
if output_hidden_states:
all_hidden_states.append(hidden_states)
if output_hidden_states:
all_hidden_states = [model_input] + all_hidden_states
else:
all_hidden_states = None
return TimesFmOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_attentions if output_attentions else None,
loc=stats[0],
scale=stats[1],
)
@staticmethod
def _prepare_4d_attention_mask(
attention_mask: Optional[torch.Tensor],
sequence_length: int,
dtype: torch.dtype,
device: torch.device,
is_causal: bool = True,
) -> Optional[torch.Tensor]:
"""
Creates 4D attention mask and combines causal and padding masks if needed.
Args:
attention_mask: Optional tensor of shape (batch_size, seq_length) containing padding mask
sequence_length: Length of the sequence
dtype: Data type of the mask
device: Device of the mask
is_causal: Whether to apply causal masking
Returns:
4D attention mask of shape (batch_size, 1, seq_length, seq_length)
"""
# Get minimum value for the dtype
min_value = torch.finfo(dtype).min if dtype.is_floating_point else torch.iinfo(dtype).min
# Handle padding mask
if attention_mask is not None:
# Convert 2D padding mask to 4D attention mask
attention_mask = attention_mask.view(attention_mask.shape[0], 1, 1, -1)
attention_mask = attention_mask * min_value
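# e.g. a padding row [0, 0, 1] becomes [[0, 0, min_value]], broadcast over the query
# dimension, so padded key positions are masked out for every query position.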
# Create causal mask if needed
if is_causal:
causal_mask = torch.triu(
torch.ones((sequence_length, sequence_length), dtype=dtype, device=device) * min_value,
diagonal=1,
)
causal_mask = causal_mask.view(1, 1, sequence_length, sequence_length)
# Combine with padding mask if it exists
if attention_mask is not None:
attention_mask = torch.minimum(attention_mask, causal_mask)
else:
attention_mask = causal_mask
return attention_mask
@staticmethod
def _timesfm_masked_mean_std(inputs: torch.Tensor, padding: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Calculates mean and standard deviation of `inputs` across axis 1.
It excludes values where `padding` is 1.
Args:
inputs: A PyTorch tensor of shape [b, n, p].
padding: A PyTorch tensor of shape [b, n, p] with values 0 or 1.
Returns:
A tuple containing the mean and standard deviation.
We return the statistics of the first patch with more than three non-padded values.
"""
# Selecting the first patch with more than 3 unpadded values.
def _get_patch_index(arr: torch.Tensor):
indices = torch.argmax((arr >= 3).to(torch.int32), dim=1)
row_sum = (arr >= 3).to(torch.int32).sum(dim=1)
return torch.where(row_sum == 0, arr.shape[1] - 1, indices)
pad_sum = torch.sum(1 - padding, dim=2)
patch_indices = _get_patch_index(pad_sum)
bidxs = torch.arange(inputs.shape[0])
arr = inputs[bidxs, patch_indices, :]
pad = padding[bidxs, patch_indices, :]
# Create a mask where padding is 0
mask = 1 - pad
# Calculate the number of valid elements
num_valid_elements = torch.sum(mask, dim=1)
num_valid_elements = torch.where(
num_valid_elements == 0,
torch.tensor(1, dtype=num_valid_elements.dtype, device=num_valid_elements.device),
num_valid_elements,
)
# Calculate the masked sum and squared sum
masked_sum = torch.sum(arr * mask, dim=1)
masked_squared_sum = torch.sum((arr * mask) ** 2, dim=1)
# Calculate the masked mean and standard deviation
masked_mean = masked_sum / num_valid_elements
masked_var = masked_squared_sum / num_valid_elements - masked_mean**2
masked_var = torch.where(
masked_var < 0.0,
torch.tensor(0.0, dtype=masked_var.dtype, device=masked_var.device),
masked_var,
)
masked_std = torch.sqrt(masked_var)
return masked_mean, masked_std
@staticmethod
def _timesfm_shift_padded_seq(mask: torch.Tensor, seq: torch.Tensor) -> torch.Tensor:
"""Shifts rows of seq based on the first 0 in each row of the mask.
Args:
mask: mask tensor of shape [B, N]
seq: seq tensor of shape [B, N, P]
Returns:
The shifted sequence.
"""
batch_size, num_seq, feature_dim = seq.shape
new_mask: torch.BoolTensor = mask == 0
# Use argmax to find the first True value in each row
indices = new_mask.to(torch.int32).argmax(dim=1)
# Handle rows with all zeros
indices[~new_mask.any(dim=1)] = -1
# Create index ranges for each sequence in the batch
idx_range = torch.arange(num_seq, device=seq.device).view(1, -1, 1).expand(batch_size, -1, feature_dim)
# Calculate shifted indices for each element in each sequence
shifted_idx = (idx_range - indices[:, None, None]) % num_seq
# Gather values from seq using shifted indices
shifted_seq = seq.gather(1, shifted_idx)
return shifted_seq
class TimesFmModelForPrediction(TimesFmPreTrainedModel):
"""TimesFM model for quantile and mean prediction."""
def __init__(self, config: TimesFmConfig):
super().__init__(config)
self.config = config
self.context_len = config.context_length
self.horizon_len = config.horizon_length
self.decoder = TimesFmModel(config)
# quantile and mean output
self.horizon_ff_layer = TimesFmResidualBlock(
input_dims=config.hidden_size,
output_dims=config.horizon_length * (1 + len(config.quantiles)),
hidden_dims=config.intermediate_size,
)
# Initialize weights and apply final processing
self.post_init()
def _preprocess(
self, inputs: Sequence[torch.Tensor], freq: Sequence[int]
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Formats and pads raw inputs to feed into the model.
This function pads (or truncates) each time series so that it matches the model
context length and builds the corresponding padding indicator.
Args:
inputs: A list of 1d Tensors. Each Tensor is the context time series of
a single forecast task.
freq: list of frequencies
Returns:
A tuple of:
- the padded input time series to meet the model required context.
- the padding indicator.
- the frequency of each series as an integer tensor of shape `(batch_size, 1)`.
"""
input_ts, input_padding, inp_freq = [], [], []
for i, ts in enumerate(inputs):
input_len = ts.shape[0]
padding = torch.zeros(input_len + self.horizon_len, dtype=ts.dtype, device=ts.device)
if input_len < self.context_len:
num_front_pad = self.context_len - input_len
ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0)
padding = torch.cat([torch.ones(num_front_pad, dtype=ts.dtype, device=padding.device), padding], dim=0)
elif input_len > self.context_len:
ts = ts[-self.context_len :]
padding = padding[-(self.context_len + self.horizon_len) :]
input_ts.append(ts)
input_padding.append(padding)
inp_freq.append(freq[i])
return (
torch.stack(input_ts, dim=0),
torch.stack(input_padding, dim=0),
torch.tensor(inp_freq, dtype=torch.int32).reshape(-1, 1),
)
def _postprocess_output(
self, model_output: torch.Tensor, stats: tuple[torch.Tensor, torch.Tensor]
) -> torch.Tensor:
"""Postprocess output of stacked transformer."""
# B x N x (H.Q)
output_ts = self.horizon_ff_layer(model_output)
# Reshape using view
b, n, _ = output_ts.shape
output_ts = output_ts.view(b, n, self.config.horizon_length, len(self.config.quantiles) + 1)
mu, sigma = stats
return output_ts * sigma[:, None, None, None] + mu[:, None, None, None]
def _quantile_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
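# Pinball (quantile) loss: for quantile q and error e = target - prediction, the penalty
# is q * e when under-predicting (e > 0) and (q - 1) * e when over-predicting (e < 0),
# averaged over all configured quantiles.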
losses = []
for i, q in enumerate(self.config.quantiles):
errors = targets - predictions[..., i]
loss = torch.max((q - 1) * errors, q * errors)
losses.append(loss.mean())
return torch.stack(losses).mean()
@can_return_tuple
@add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TimesFmOutputForPrediction, config_class=_CONFIG_FOR_DOC)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TimesFmOutputForPrediction,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
past_values: Sequence[torch.Tensor],
freq: Optional[Sequence[Union[torch.Tensor, int]]] = None,
window_size: Optional[int] = None,
future_values: Optional[torch.Tensor] = None,
forecast_context_len: Optional[int] = None,
return_forecast_on_context: bool = False,
truncate_negative: bool = False,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) -> TimesFmOutputForPrediction:
r"""
window_size (`int`, *optional*):
Window size of trend + residual decomposition. If None then we do not do decomposition.
future_values (`torch.Tensor`, *optional*):
Optional future time series values to be used for loss computation.
forecast_context_len (`int`, *optional*):
Optional max context length.
return_forecast_on_context (`bool`, *optional*):
True to return the forecast on the context when available, i.e. after the first input patch.
truncate_negative (`bool`, *optional*):
Truncate to only non-negative values if any of the contexts have non-negative values,
otherwise do nothing.
output_attentions (`bool`, *optional*):
Whether to output the attentions.
output_hidden_states (`bool`, *optional*):
Whether to output the hidden states.
Returns:
A TimesFmOutputForPrediction object or a tuple containing:
- the mean forecast of size (# past_values, # forecast horizon),
- the full forecast (mean + quantiles) of size
(# past_values, # forecast horizon, 1 + # quantiles).
- loss: the mean squared error loss + quantile loss if `future_values` is provided.
"""
if forecast_context_len is None:
fcontext_len = self.context_len
else:
fcontext_len = forecast_context_len
# Get device from first input tensor
device = past_values[0].device
# Truncate inputs to forecast_context_len
inputs = [ts[-fcontext_len:] for ts in past_values]
inp_min = torch.min(torch.stack([torch.min(ts) for ts in inputs]))
if window_size is not None:
new_inputs = []
new_freqs = []
for i, ts in enumerate(inputs):
new_inputs.extend(self._timesfm_moving_average(ts, window_size))
if freq is not None:
new_freqs.extend([freq[i]] * 2)
inputs = new_inputs
if freq is not None:
freq = new_freqs
if freq is None:
logger.info("No frequency provided via `freq`. Default to high (0).")
freq = [0] * len(inputs)
if output_attentions is None:
output_attentions = self.config.output_attentions
if output_hidden_states is None:
output_hidden_states = self.config.output_hidden_states
input_ts, input_padding, inp_freq = self._preprocess(inputs, freq)
# Move tensors to the same device as input
input_ts = input_ts.to(device)
input_padding = input_padding.to(device)
inp_freq = inp_freq.to(device)
final_out = input_ts
context_len = final_out.shape[1]
full_outputs = []
if input_padding.shape[1] != final_out.shape[1] + self.horizon_len:
raise ValueError(
"Length of paddings must match length of input + horizon_len:"
f" {input_padding.shape[1]} != {final_out.shape[1]} + {self.horizon_len}"
)
output_patch_len = self.config.horizon_length
num_decode_patches = (self.horizon_len + output_patch_len - 1) // output_patch_len
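# Autoregressive decoding: each step feeds the last `fcontext_len` values (the context plus
# the forecasts produced so far) through the decoder and appends `output_patch_len` new
# points, until at least `horizon_len` future points have been generated.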
for step_index in range(num_decode_patches):
current_padding = input_padding[:, 0 : final_out.shape[1]]
input_ts = final_out[:, -fcontext_len:]
input_padding = current_padding[:, -fcontext_len:]
decoder_output = self.decoder(
past_values=input_ts,
past_values_padding=input_padding,
freq=inp_freq,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
fprop_outputs = self._postprocess_output(
decoder_output.last_hidden_state,
(decoder_output.loc, decoder_output.scale),
)
if return_forecast_on_context and step_index == 0:
# For the first decoding step, collect the model forecast on the context,
# excluding the first input patch, for which no forecast is available.
new_full_ts = fprop_outputs[:, :-1, : self.config.patch_length, :]
# We have to use reshape and not view for non-contiguous memory
new_full_ts = new_full_ts.reshape(new_full_ts.size(0), -1, new_full_ts.size(3))
full_outputs.append(new_full_ts)
# (full batch, last patch, output_patch_len, index of mean forecast = 0)
new_ts = fprop_outputs[:, -1, :output_patch_len, 0]
new_full_ts = fprop_outputs[:, -1, :output_patch_len, :]
# (full batch, last patch, output_patch_len, all output indices)
full_outputs.append(new_full_ts)
final_out = torch.concatenate([final_out, new_ts], axis=-1)
if return_forecast_on_context:
# `full_outputs` indexing starts at after the first input patch.
full_outputs = torch.concatenate(full_outputs, axis=1)[
:, : (context_len - self.config.patch_length + self.horizon_len), :
]
else:
# `full_outputs` indexing starts at the forecast horizon.
full_outputs = torch.concatenate(full_outputs, axis=1)[:, 0 : self.horizon_len, :]
mean_outputs = full_outputs[:, :, 0]
if window_size is not None:
mean_outputs = mean_outputs[0::2, ...] + mean_outputs[1::2, ...]
full_outputs = full_outputs[0::2, ...] + full_outputs[1::2, ...]
if inp_min >= 0 and truncate_negative:
mean_outputs = torch.maximum(mean_outputs, torch.zeros_like(mean_outputs))
full_outputs = torch.maximum(full_outputs, torch.zeros_like(full_outputs))
loss = None
if future_values is not None:
mse_loss = F.mse_loss(mean_outputs, future_values)
quantile_loss = self._quantile_loss(full_outputs[:, :, 1:], future_values)
loss = mse_loss + quantile_loss
return TimesFmOutputForPrediction(
last_hidden_state=decoder_output.last_hidden_state,
attentions=decoder_output.attentions if output_attentions else None,
hidden_states=decoder_output.hidden_states if output_hidden_states else None,
mean_predictions=mean_outputs,
full_predictions=full_outputs,
loss=loss,
)
@staticmethod
def _timesfm_moving_average(arr: torch.Tensor, window_size: int) -> list[torch.Tensor]:
"""Calculates the moving average using PyTorch's convolution function."""
# Pad with zeros to handle initial window positions
arr_padded = F.pad(arr, (window_size - 1, 0), "constant", 0)
# Create a convolution kernel
kernel = torch.ones(window_size, dtype=arr.dtype, device=arr.device) / window_size
# Apply convolution to calculate the moving average
smoothed_arr = F.conv1d(arr_padded.view(1, 1, -1), kernel.view(1, 1, -1)).squeeze()
return [smoothed_arr, arr - smoothed_arr]
__all__ = ["TimesFmModelForPrediction", "TimesFmPreTrainedModel", "TimesFmModel"]

View File

@ -0,0 +1,197 @@
# coding=utf-8
# Copyright 2025 Google LLC and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import unittest
from typing import List
import numpy as np
import torch
from transformers import TimesFmConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from transformers.utils import is_torch_fx_available
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin
if is_torch_fx_available():
pass
if is_torch_available():
from transformers import TimesFmModelForPrediction
TOLERANCE = 1e-4
class TimesFmModelTester:
def __init__(
self,
parent,
patch_length: int = 32,
context_length: int = 512,
horizon_length: int = 128,
freq_size: int = 3,
num_hidden_layers: int = 1,
hidden_size: int = 16,
intermediate_size: int = 32,
head_dim: int = 8,
num_heads: int = 2,
tolerance: float = 1e-6,
rms_norm_eps: float = 1e-6,
quantiles: List[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
pad_val: float = 1123581321.0,
use_positional_embedding: bool = True,
initializer_factor: float = 0.0,
is_training: bool = False,
batch_size: int = 3,
):
self.parent = parent
self.patch_length = patch_length
self.context_length = context_length
self.horizon_length = horizon_length
self.quantiles = quantiles
self.pad_val = pad_val
self.freq_size = freq_size
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.head_dim = head_dim
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_heads
self.tolerance = tolerance
self.rms_norm_eps = rms_norm_eps
self.use_positional_embedding = use_positional_embedding
self.initializer_factor = initializer_factor
self.is_training = is_training
self.batch_size = batch_size
# The size of test input
self.seq_length = context_length // patch_length
self.hidden_size = hidden_size
def get_config(self):
return TimesFmConfig(
patch_length=self.patch_length,
context_length=self.context_length,
horizon_length=self.horizon_length,
quantiles=self.quantiles,
pad_val=self.pad_val,
freq_size=self.freq_size,
hidden_size=self.hidden_size,
intermediate_size=self.intermediate_size,
head_dim=self.head_dim,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
tolerance=self.tolerance,
rms_norm_eps=self.rms_norm_eps,
use_positional_embedding=self.use_positional_embedding,
initializer_factor=self.initializer_factor,
)
def get_pipeline_config(self):
return self.get_config()
def prepare_config_and_inputs(self):
forecast_input = [
torch.tensor(np.sin(np.linspace(0, 20, 100)), dtype=torch.float32, device=torch_device),
torch.tensor(np.cos(np.linspace(0, 20, 100)), dtype=torch.float32, device=torch_device),
torch.tensor(np.tan(np.linspace(0, 20, 100)), dtype=torch.float32, device=torch_device),
]
frequency_input = torch.tensor([0, 1, 2], dtype=torch.long, device=torch_device)
return (self.get_config(), torch.stack(forecast_input, dim=0), frequency_input)
def prepare_config_and_inputs_for_common(self):
(config, forecast_input, frequency_input) = self.prepare_config_and_inputs()
inputs_dict = {
"past_values": forecast_input,
"freq": frequency_input,
}
return config, inputs_dict
@require_torch
class TimesFmModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (TimesFmModelForPrediction,) if is_torch_available() else ()
all_generative_model_classes = ()
all_parallelizable_model_classes = ()
fx_compatible = False
test_pruning = False
test_resize_embeddings = False
test_model_parallel = False
is_encoder_decoder = False
test_inputs_embeds = False
def setUp(self):
self.model_tester = TimesFmModelTester(self)
self.config_tester = ConfigTester(self, config_class=TimesFmConfig)
def test_create_and_run_model(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = TimesFmModelForPrediction(config)
model.to(torch_device)
model.eval()
results = model(**inputs_dict)
assert results.mean_predictions is not None
@unittest.skip(reason="Compile not yet supported because of masks")
def test_sdpa_can_dispatch_on_flash(self):
pass
@unittest.skip(reason="Model does not have input embeddings")
def test_model_get_set_embeddings(self):
pass
@unittest.skip(reason="Model does not have head mask")
def test_headmasking(self):
pass
# the main input name is `past_values`
def test_model_main_input_name(self):
model_signature = inspect.signature(getattr(TimesFmModelForPrediction, "forward"))
# The main input is the name of the argument after `self`
observed_main_input_name = list(model_signature.parameters.keys())[1]
self.assertEqual(TimesFmModelForPrediction.main_input_name, observed_main_input_name)
@require_torch
@slow
class TimesFmModelIntegrationTests(unittest.TestCase):
def test_inference_no_head(self):
model = TimesFmModelForPrediction.from_pretrained("google/timesfm-2.0-500m-pytorch", revision="refs/pr/7").to(
torch_device
)
forecast_input = [
np.sin(np.linspace(0, 20, 100)),
np.sin(np.linspace(0, 20, 200)),
np.sin(np.linspace(0, 20, 400)),
]
forecast_input_tensor = [torch.tensor(ts, dtype=torch.float32, device=torch_device) for ts in forecast_input]
frequency_input = [0, 1, 2]
with torch.no_grad():
output = model(past_values=forecast_input_tensor, freq=frequency_input).last_hidden_state
self.assertEqual(
output.shape,
torch.Size([3, model.config.context_length // model.config.patch_length, model.config.hidden_size]),
)
expected_slice = torch.tensor(
[[-0.4267, -0.7273, -0.3932], [-0.4267, -0.7273, -0.3932], [-0.4267, -0.7273, -0.3932]],
device=torch_device,
)
self.assertTrue(torch.allclose(output[0, :3, :3], expected_slice, atol=TOLERANCE))

View File

@ -156,6 +156,7 @@ IGNORE_NON_TESTED = (
"Llama4VisionModel", # Building part of bigger (tested) model. # TODO: add tests
"Emu3VQVAE", # Building part of bigger (tested) model
"Emu3TextModel", # Building part of bigger (tested) model
"TimesFmModel", # Building part of bigger (tested) model
]
)