diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 91522f0fc30..0bd54e5a3b1 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -1059,6 +1059,8 @@
title: PatchTST
- local: model_doc/time_series_transformer
title: Time Series Transformer
+ - local: model_doc/timesfm
+ title: TimesFM
title: Time series models
- sections:
- local: model_doc/graphormer
diff --git a/docs/source/en/model_doc/timesfm.md b/docs/source/en/model_doc/timesfm.md
new file mode 100644
index 00000000000..f5e27994919
--- /dev/null
+++ b/docs/source/en/model_doc/timesfm.md
@@ -0,0 +1,88 @@
+
+
+# TimesFM
+
+
+
+
+
+
+## Overview
+
+TimesFM (Time Series Foundation Model) is a pretrained time-series foundation model proposed in [A decoder-only foundation model for time-series forecasting](https://huggingface.co/papers/2310.10688) by Abhimanyu Das, Weihao Kong, Rajat Sen, and Yichen Zhou. It is a decoder-only model that takes non-overlapping patches of time-series data as input and autoregressively predicts output patches of a fixed length.
+
+
+The abstract from the paper is the following:
+
+*Motivated by recent advances in large language models for Natural Language Processing (NLP), we design a time-series foundation model for forecasting whose out-of-the-box zero-shot performance on a variety of public datasets comes close to the accuracy of state-of-the-art supervised forecasting models for each individual dataset. Our model is based on pretraining a patched-decoder style attention model on a large time-series corpus, and can work well across different forecasting history lengths, prediction lengths and temporal granularities.*
+
+
+This model was contributed by [kashif](https://huggingface.co/kashif).
+The original code can be found [here](https://github.com/google-research/timesfm).
+
+
+To use the model:
+
+```python
+import numpy as np
+import torch
+from transformers import TimesFmModelForPrediction
+
+
+model = TimesFmModelForPrediction.from_pretrained(
+ "google/timesfm-2.0-500m-pytorch",
+ torch_dtype=torch.bfloat16,
+ attn_implementation="sdpa",
+ device_map="cuda" if torch.cuda.is_available() else None
+)
+
+
+# Create dummy inputs
+forecast_input = [
+ np.sin(np.linspace(0, 20, 100)),
+ np.sin(np.linspace(0, 20, 200)),
+ np.sin(np.linspace(0, 20, 400)),
+]
+frequency_input = [0, 1, 2]
+
+# Convert inputs to sequence of tensors
+forecast_input_tensor = [
+ torch.tensor(ts, dtype=torch.bfloat16).to("cuda" if torch.cuda.is_available() else "cpu")
+ for ts in forecast_input
+]
+frequency_input_tensor = torch.tensor(frequency_input, dtype=torch.long).to(
+ "cuda" if torch.cuda.is_available() else "cpu"
+)
+
+# Get predictions from the pre-trained model
+with torch.no_grad():
+ outputs = model(past_values=forecast_input_tensor, freq=frequency_input_tensor, return_dict=True)
+ point_forecast_conv = outputs.mean_predictions.float().cpu().numpy()
+ quantile_forecast_conv = outputs.full_predictions.float().cpu().numpy()
+```
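+
+The mean forecast lives at index 0 of the last axis of `full_predictions`; the remaining entries follow `config.quantiles` (the 0.1 ... 0.9 deciles by default). A minimal sketch of slicing the quantiles, continuing the example above:
+
+```python
+# Index 0 is the mean forecast; index i (1 <= i <= 9) is the quantiles[i - 1]
+# forecast for the default quantiles [0.1, ..., 0.9].
+median_forecast = quantile_forecast_conv[..., 5]  # 0.5 quantile
+lower = quantile_forecast_conv[..., 1]  # 0.1 quantile
+upper = quantile_forecast_conv[..., 9]  # 0.9 quantile (80% central interval)
+```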
+
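+When `future_values` is provided, the forward pass also computes a loss (mean squared error on the mean forecast plus a quantile loss). A minimal sketch, assuming illustrative targets of shape `(batch, horizon_length)` that continue the series above:
+
+```python
+# Hypothetical ground-truth continuations; any (batch, 128) tensor on the
+# model's device and dtype works here.
+future_values = torch.stack(
+    [torch.tensor(np.sin(np.linspace(20, 25, 128)), dtype=torch.bfloat16) for _ in range(3)]
+).to(model.device)
+
+with torch.no_grad():
+    outputs = model(
+        past_values=forecast_input_tensor,
+        freq=frequency_input_tensor,
+        future_values=future_values,
+        return_dict=True,
+    )
+print(outputs.loss)
+```
+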
+## TimesFmConfig
+
+[[autodoc]] TimesFmConfig
+
+## TimesFmModel
+
+[[autodoc]] TimesFmModel
+ - forward
+
+## TimesFmModelForPrediction
+
+[[autodoc]] TimesFmModelForPrediction
+ - forward
diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py
index 73961f4a6a8..94a68374cb7 100644
--- a/src/transformers/models/__init__.py
+++ b/src/transformers/models/__init__.py
@@ -279,6 +279,7 @@ if TYPE_CHECKING:
from .tapas import *
from .textnet import *
from .time_series_transformer import *
+ from .timesfm import *
from .timesformer import *
from .timm_backbone import *
from .timm_wrapper import *
diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py
index 1fbc2cb168f..c28ec163d1c 100644
--- a/src/transformers/models/auto/configuration_auto.py
+++ b/src/transformers/models/auto/configuration_auto.py
@@ -313,6 +313,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("tapas", "TapasConfig"),
("textnet", "TextNetConfig"),
("time_series_transformer", "TimeSeriesTransformerConfig"),
+ ("timesfm", "TimesFmConfig"),
("timesformer", "TimesformerConfig"),
("timm_backbone", "TimmBackboneConfig"),
("timm_wrapper", "TimmWrapperConfig"),
@@ -681,6 +682,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("tapex", "TAPEX"),
("textnet", "TextNet"),
("time_series_transformer", "Time Series Transformer"),
+ ("timesfm", "TimesFm"),
("timesformer", "TimeSformer"),
("timm_backbone", "TimmBackbone"),
("timm_wrapper", "TimmWrapperModel"),
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index e630333602c..af832ee2393 100644
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -281,6 +281,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("tapas", "TapasModel"),
("textnet", "TextNetModel"),
("time_series_transformer", "TimeSeriesTransformerModel"),
+ ("timesfm", "TimesFmModel"),
("timesformer", "TimesformerModel"),
("timm_backbone", "TimmBackbone"),
("timm_wrapper", "TimmWrapperModel"),
@@ -1542,6 +1543,12 @@ MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING_NAMES = OrderedDict(
]
)
+MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING_NAMES = OrderedDict(
+ [
+ ("timesfm", "TimesFmModelForPrediction"),
+ ]
+)
+
MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES = OrderedDict(
[
("swin2sr", "Swin2SRForImageSuperResolution"),
@@ -1650,6 +1657,10 @@ MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING_NAMES
)
+MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING = _LazyAutoMapping(
+ CONFIG_MAPPING_NAMES, MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING_NAMES
+)
+
MODEL_FOR_IMAGE_TO_IMAGE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES)
@@ -1820,6 +1831,15 @@ AutoModelForSemanticSegmentation = auto_class_update(
)
+class AutoModelForTimeSeriesPrediction(_BaseAutoModelClass):
+ _model_mapping = MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING
+
+
+AutoModelForTimeSeriesPrediction = auto_class_update(
+ AutoModelForTimeSeriesPrediction, head_doc="time-series prediction"
+)
+
+
class AutoModelForUniversalSegmentation(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING
@@ -1994,6 +2014,7 @@ __all__ = [
"MODEL_FOR_TEXT_ENCODING_MAPPING",
"MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING",
"MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING",
+ "MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING",
"MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
"MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING",
"MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING",
diff --git a/src/transformers/models/timesfm/__init__.py b/src/transformers/models/timesfm/__init__.py
new file mode 100644
index 00000000000..12f1541b9c5
--- /dev/null
+++ b/src/transformers/models/timesfm/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_timesfm import *
+ from .modeling_timesfm import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/src/transformers/models/timesfm/configuration_timesfm.py b/src/transformers/models/timesfm/configuration_timesfm.py
new file mode 100644
index 00000000000..bd371cf1b21
--- /dev/null
+++ b/src/transformers/models/timesfm/configuration_timesfm.py
@@ -0,0 +1,129 @@
+# coding=utf-8
+# Copyright 2025 Google LLC and HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""TimesFM model configuration"""
+
+from typing import List
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class TimesFmConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`TimesFmModel`] or a [`TimesFmModelForPrediction`]. It is used to
+ instantiate a TimesFM model according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the TimesFM
+ [google/timesfm-2.0-500m-pytorch](https://huggingface.co/google/timesfm-2.0-500m-pytorch) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Arguments:
+ patch_length (`int`, *optional*, defaults to 32):
+ The length of one patch in the input sequence.
+ context_length (`int`, *optional*, defaults to 512):
+ The length of the input context.
+ horizon_length (`int`, *optional*, defaults to 128):
+ The length of the prediction horizon.
+ freq_size (`int`, *optional*, defaults to 3):
+ The number of frequency embeddings.
+ num_hidden_layers (`int`, *optional*, defaults to 50):
+ Number of Transformer layers.
+ hidden_size (`int`, *optional*, defaults to 1280):
+ Dimension of the hidden representations of the model.
+ intermediate_size (`int`, *optional*, defaults to 1280):
+ Dimension of the MLP representations.
+ head_dim (`int`, *optional*, defaults to 80):
+ Size of the key, query, value projections per attention head. The `inner_dim` of the projection layer will
+ be defined as `num_attention_heads * head_dim`.
+ num_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ tolerance (`float`, *optional*, defaults to 1e-06):
+ The tolerance for the quantile loss.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the RMS normalization layers.
+ quantiles (`List[float]`, *optional*, defaults to `[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]`):
+ The quantiles to predict.
+ pad_val (`float`, *optional*, defaults to 1123581321.0):
+ The special value used to mark padded entries in the input time series.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probability for the attention scores.
+ use_positional_embedding (`bool`, *optional*, defaults to `False`):
+ Whether to add positional embeddings.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ min_timescale (`int`, *optional*, defaults to 1):
+ The start of the geometric positional index. Determines the periodicity of
+ the added signal.
+ max_timescale (`int`, *optional*, defaults to 10000):
+ The end of the geometric positional index. Determines the frequency of the
+ added signal.
+ """
+
+ model_type = "timesfm"
+ keys_to_ignore_at_inference = []
+ is_encoder_decoder = False
+
+ def __init__(
+ self,
+ patch_length: int = 32,
+ context_length: int = 512,
+ horizon_length: int = 128,
+ freq_size: int = 3,
+ num_hidden_layers: int = 50,
+ hidden_size: int = 1280,
+ intermediate_size: int = 1280,
+ head_dim: int = 80,
+ num_attention_heads: int = 16,
+ tolerance: float = 1e-6,
+ rms_norm_eps: float = 1e-6,
+ quantiles: List[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
+ pad_val: float = 1123581321.0,
+ attention_dropout: float = 0.0,
+ use_positional_embedding: bool = False,
+ initializer_range: float = 0.02,
+ min_timescale: int = 1,
+ max_timescale: int = 10_000,
+ **kwargs,
+ ):
+ self.patch_length = patch_length
+ self.context_length = context_length
+ self.horizon_length = horizon_length
+ self.quantiles = quantiles
+ self.pad_val = pad_val
+ self.freq_size = freq_size
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.head_dim = head_dim
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.tolerance = tolerance
+ self.rms_norm_eps = rms_norm_eps
+ self.attention_dropout = attention_dropout
+ self.use_positional_embedding = use_positional_embedding
+ self.initializer_range = initializer_range
+ self.min_timescale = min_timescale
+ self.max_timescale = max_timescale
+
+ super().__init__(
+ is_encoder_decoder=self.is_encoder_decoder,
+ **kwargs,
+ )
+
+
+__all__ = ["TimesFmConfig"]
diff --git a/src/transformers/models/timesfm/convert_timesfm_original_to_hf.py b/src/transformers/models/timesfm/convert_timesfm_original_to_hf.py
new file mode 100644
index 00000000000..06674fb087c
--- /dev/null
+++ b/src/transformers/models/timesfm/convert_timesfm_original_to_hf.py
@@ -0,0 +1,275 @@
+import argparse
+import os
+import re
+import shutil
+
+import numpy as np
+import timesfm
+import torch
+
+from transformers import TimesFmConfig, TimesFmModelForPrediction
+
+
+"""
+Sample usage:
+
+```
+python src/transformers/models/timesfm/convert_timesfm_original_to_hf.py \
+ --output_dir /output/path
+```
+"""
+
+
+def get_nested_attr(obj, key):
+ """Recursively retrieves an attribute from an object, handling list/tuple indexing if present."""
+ parts = key.split(".")
+ for part in parts:
+ match = re.match(r"(.*)\[(\d+)\]", part) # Handle list indexing like `layers[0]`
+ if match:
+ attr_name, index = match.groups()
+ obj = getattr(obj, attr_name)[int(index)] # Access list/tuple element
+ else:
+ obj = getattr(obj, part) # Regular attribute access
+ return obj
+
+
+def write_model(model_path, safe_serialization=True, huggingface_repo_id="google/timesfm-2.0-500m-pytorch"):
+ os.makedirs(model_path, exist_ok=True)
+ tmp_model_path = os.path.join(model_path, "tmp")
+ os.makedirs(tmp_model_path, exist_ok=True)
+
+ tfm = timesfm.TimesFm(
+ hparams=timesfm.TimesFmHparams(
+ backend="cuda" if torch.cuda.is_available() else "cpu",
+ per_core_batch_size=32,
+ horizon_len=128,
+ input_patch_len=32,
+ output_patch_len=128,
+ num_layers=50,
+ model_dims=1280,
+ use_positional_embedding=False,
+ ),
+ checkpoint=timesfm.TimesFmCheckpoint(huggingface_repo_id=huggingface_repo_id),
+ )
+
+ timesfm_config = TimesFmConfig(
+ patch_length=tfm.hparams.input_patch_len,
+ context_length=tfm.hparams.context_len,
+ horizon_length=tfm.hparams.horizon_len,
+ num_hidden_layers=tfm.hparams.num_layers,
+ hidden_size=tfm.hparams.model_dims,
+ intermediate_size=tfm.hparams.model_dims,
+ head_dim=tfm.hparams.model_dims // tfm.hparams.num_heads,
+ num_attention_heads=tfm.hparams.num_heads,
+ use_positional_embedding=tfm.hparams.use_positional_embedding,
+ )
+ timesfm_config.save_pretrained(tmp_model_path)
+ timesfm_model = TimesFmModelForPrediction(timesfm_config)
+
+ # copy the weights from the original model to the new model
+ original_model = tfm._model
+
+ # mapping of the layers from the original model to the transformer model
+ MODEL_LAYER_MAPPING = {
+ "input_ff_layer.hidden_layer[0].weight": "decoder.input_ff_layer.input_layer.weight",
+ "input_ff_layer.hidden_layer[0].bias": "decoder.input_ff_layer.input_layer.bias",
+ "input_ff_layer.output_layer.weight": "decoder.input_ff_layer.output_layer.weight",
+ "input_ff_layer.output_layer.bias": "decoder.input_ff_layer.output_layer.bias",
+ "input_ff_layer.residual_layer.weight": "decoder.input_ff_layer.residual_layer.weight",
+ "input_ff_layer.residual_layer.bias": "decoder.input_ff_layer.residual_layer.bias",
+ "freq_emb.weight": "decoder.freq_emb.weight",
+ "horizon_ff_layer.hidden_layer[0].weight": "horizon_ff_layer.input_layer.weight",
+ "horizon_ff_layer.hidden_layer[0].bias": "horizon_ff_layer.input_layer.bias",
+ "horizon_ff_layer.output_layer.weight": "horizon_ff_layer.output_layer.weight",
+ "horizon_ff_layer.output_layer.bias": "horizon_ff_layer.output_layer.bias",
+ "horizon_ff_layer.residual_layer.weight": "horizon_ff_layer.residual_layer.weight",
+ "horizon_ff_layer.residual_layer.bias": "horizon_ff_layer.residual_layer.bias",
+ }
+
+ TRANSFORMER_LAYER_MAPPING = {
+ "stacked_transformer.layers[{i}].self_attn.qkv_proj.weight": "decoder.layers[{i}].self_attn.qkv_proj.weight",
+ "stacked_transformer.layers[{i}].self_attn.qkv_proj.bias": "decoder.layers[{i}].self_attn.qkv_proj.bias",
+ "stacked_transformer.layers[{i}].self_attn.o_proj.weight": "decoder.layers[{i}].self_attn.o_proj.weight",
+ "stacked_transformer.layers[{i}].self_attn.o_proj.bias": "decoder.layers[{i}].self_attn.o_proj.bias",
+ "stacked_transformer.layers[{i}].self_attn.scaling": "decoder.layers[{i}].self_attn.scaling",
+ "stacked_transformer.layers[{i}].mlp.gate_proj.weight": "decoder.layers[{i}].mlp.gate_proj.weight",
+ "stacked_transformer.layers[{i}].mlp.gate_proj.bias": "decoder.layers[{i}].mlp.gate_proj.bias",
+ "stacked_transformer.layers[{i}].mlp.down_proj.weight": "decoder.layers[{i}].mlp.down_proj.weight",
+ "stacked_transformer.layers[{i}].mlp.down_proj.bias": "decoder.layers[{i}].mlp.down_proj.bias",
+ "stacked_transformer.layers[{i}].mlp.layer_norm.weight": "decoder.layers[{i}].mlp.layer_norm.weight",
+ "stacked_transformer.layers[{i}].mlp.layer_norm.bias": "decoder.layers[{i}].mlp.layer_norm.bias",
+ "stacked_transformer.layers[{i}].input_layernorm.weight": "decoder.layers[{i}].input_layernorm.weight",
+ }
+
+ for old_key, new_key in MODEL_LAYER_MAPPING.items():
+ try:
+ old_attr = get_nested_attr(original_model, old_key) # Get tensor from original model
+ new_attr = get_nested_attr(timesfm_model, new_key) # Get corresponding attribute in new model
+ new_attr.data.copy_(old_attr.data) # Copy data
+ except AttributeError:
+ print(f"Skipping {old_key} (not found in original model).")
+
+ num_layers = len(timesfm_model.decoder.layers)
+ for i in range(num_layers):
+ for old_template, new_template in TRANSFORMER_LAYER_MAPPING.items():
+ old_key = old_template.format(i=i)
+ new_key = new_template.format(i=i)
+
+ try:
+ # Get tensor from original model
+ old_attr = get_nested_attr(original_model, old_key)
+ if "qkv_proj" in old_key:
+ # Split the tensor into q, k, v projections
+ q_proj, k_proj, v_proj = (
+ old_attr[: tfm.hparams.model_dims, ...],
+ old_attr[tfm.hparams.model_dims : tfm.hparams.model_dims * 2, ...],
+ old_attr[tfm.hparams.model_dims * 2 :, ...],
+ )
+ # Get corresponding attribute in new model
+ q_key = new_key.replace("qkv_proj", "q_proj")
+ q_attr = get_nested_attr(timesfm_model, q_key)
+ q_attr.data.copy_(q_proj.data) # Copy data
+ k_key = new_key.replace("qkv_proj", "k_proj")
+ k_attr = get_nested_attr(timesfm_model, k_key)
+ k_attr.data.copy_(k_proj.data) # Copy data
+ v_key = new_key.replace("qkv_proj", "v_proj")
+ v_attr = get_nested_attr(timesfm_model, v_key)
+ v_attr.data.copy_(v_proj.data) # Copy data
+ else:
+ # Get corresponding attribute in new model
+ new_attr = get_nested_attr(timesfm_model, new_key)
+ new_attr.data.copy_(old_attr.data) # Copy data
+ except AttributeError:
+ print(f"Skipping {old_key} (not found in original model).")
+
+ timesfm_model.save_pretrained(model_path, safe_serialization=safe_serialization)
+ shutil.rmtree(tmp_model_path)
+
+
+def check_outputs(model_path, huggingface_repo_id):
+ """Compares outputs between original and converted models."""
+ print("\nChecking model outputs...")
+
+ # Load original model
+ tfm = timesfm.TimesFm(
+ hparams=timesfm.TimesFmHparams(
+ backend="cuda" if torch.cuda.is_available() else "cpu",
+ per_core_batch_size=32,
+ horizon_len=128,
+ input_patch_len=32,
+ output_patch_len=128,
+ num_layers=50,
+ model_dims=1280,
+ use_positional_embedding=False,
+ point_forecast_mode="mean",
+ ),
+ checkpoint=timesfm.TimesFmCheckpoint(huggingface_repo_id=huggingface_repo_id),
+ )
+
+ # Load converted model
+ converted_model = TimesFmModelForPrediction.from_pretrained(
+ model_path,
+ torch_dtype=torch.bfloat16,
+ attn_implementation="sdpa",
+ ).to("cuda" if torch.cuda.is_available() else "cpu")
+ converted_model.eval() # Set to evaluation mode
+
+ # Create test inputs
+ forecast_input = [
+ np.sin(np.linspace(0, 20, 100)),
+ np.sin(np.linspace(0, 20, 200)),
+ np.sin(np.linspace(0, 20, 400)),
+ ]
+ frequency_input = [0, 1, 2]
+
+ # Get predictions from original model
+ point_forecast_orig, quantile_forecast_orig = tfm.forecast(
+ forecast_input,
+ freq=frequency_input,
+ )
+
+ # Convert inputs to sequence of tensors
+ forecast_input_tensor = [
+ torch.tensor(ts, dtype=torch.bfloat16).to("cuda" if torch.cuda.is_available() else "cpu")
+ for ts in forecast_input
+ ]
+ frequency_input_tensor = torch.tensor(frequency_input, dtype=torch.long).to(
+ "cuda" if torch.cuda.is_available() else "cpu"
+ )
+
+ # Get predictions from converted model
+ with torch.no_grad():
+ outputs = converted_model(past_values=forecast_input_tensor, freq=frequency_input_tensor, return_dict=True)
+ point_forecast_conv = outputs.mean_predictions.float().cpu().numpy()
+ quantile_forecast_conv = outputs.full_predictions.float().cpu().numpy()
+
+ # Compare outputs
+ point_forecast_diff = np.abs(point_forecast_orig - point_forecast_conv)
+ quantile_forecast_diff = np.abs(quantile_forecast_orig - quantile_forecast_conv)
+
+ max_point_diff = point_forecast_diff.max()
+ mean_point_diff = point_forecast_diff.mean()
+ max_quantile_diff = quantile_forecast_diff.max()
+ mean_quantile_diff = quantile_forecast_diff.mean()
+
+ print("\nOutput comparison:")
+ print(f"Point forecast - Max difference: {max_point_diff:.6f}")
+ print(f"Point forecast - Mean difference: {mean_point_diff:.6f}")
+ print(f"Quantile forecast - Max difference: {max_quantile_diff:.6f}")
+ print(f"Quantile forecast - Mean difference: {mean_quantile_diff:.6f}")
+
+ # Define acceptable thresholds
+ POINT_THRESHOLD = 1e-5
+ QUANTILE_THRESHOLD = 1e-5
+
+ if max_point_diff > POINT_THRESHOLD or max_quantile_diff > QUANTILE_THRESHOLD:
+ raise ValueError(
+ f"Output mismatch detected!\n"
+ f"Point forecast max diff: {max_point_diff} (threshold: {POINT_THRESHOLD})\n"
+ f"Quantile forecast max diff: {max_quantile_diff} (threshold: {QUANTILE_THRESHOLD})"
+ )
+
+ print("\nâ All outputs match within acceptable tolerance!")
+
+ # Optional: Print shapes for verification
+ print("\nOutput shapes:")
+ print(f"Original point forecast: {point_forecast_orig.shape}")
+ print(f"Converted point forecast: {point_forecast_conv.shape}")
+ print(f"Original quantile forecast: {quantile_forecast_orig.shape}")
+ print(f"Converted quantile forecast: {quantile_forecast_conv.shape}")
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--output_dir",
+ required=True,
+ help="Location to write HF model and tokenizer",
+ )
+ parser.add_argument(
+ "--safe_serialization", type=bool, default=True, help="Whether or not to save using `safetensors`."
+ )
+ parser.add_argument(
+ "--huggingface_repo_id",
+ type=str,
+ default="google/timesfm-2.0-500m-pytorch",
+ help="The Hugging Face repository ID to use for the model.",
+ )
+ args = parser.parse_args()
+
+ # if the saved model file exists, skip the conversion
+ if os.path.exists(
+ os.path.join(args.output_dir, "model.safetensors" if args.safe_serialization else "pytorch_model.bin")
+ ):
+ print(f"Model already exists in {args.output_dir}, skipping conversion.")
+ else:
+ write_model(
+ model_path=args.output_dir,
+ safe_serialization=args.safe_serialization,
+ huggingface_repo_id=args.huggingface_repo_id,
+ )
+ check_outputs(args.output_dir, args.huggingface_repo_id)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py
new file mode 100644
index 00000000000..8a12b2c56cf
--- /dev/null
+++ b/src/transformers/models/timesfm/modeling_timesfm.py
@@ -0,0 +1,904 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/timesfm/modular_timesfm.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_timesfm.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 Google LLC and HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from dataclasses import dataclass
+from typing import Callable, Optional, Sequence, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...integrations import use_kernel_forward_from_hub
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_outputs import BaseModelOutput
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import (
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ can_return_tuple,
+ logging,
+ replace_return_docstrings,
+)
+from .configuration_timesfm import TimesFmConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "google/timesfm-2.0-500m-pytorch"
+_CONFIG_FOR_DOC = "TimesFmConfig"
+
+
+@dataclass
+class TimesFmOutput(BaseModelOutput):
+ """
+ Args:
+ loc (`torch.Tensor` of shape `(batch_size,)`):
+ The mean of the time series inputs.
+ scale (`torch.Tensor` of shape `(batch_size,)`):
+ The scale of the time series inputs.
+ """
+
+ loc: Optional[torch.Tensor] = None
+ scale: Optional[torch.Tensor] = None
+
+
+@dataclass
+class TimesFmOutputForPrediction(BaseModelOutput):
+ """
+ Args:
+ mean_predictions (`torch.Tensor` of shape `(batch_size, sequence_length)`):
+ The mean predictions of the time series.
+ full_predictions (`torch.Tensor` of shape `(batch_size, sequence_length)`):
+ The full predictions of the time series including the mean and the quantiles.
+ loss (`torch.Tensor` of shape `(1,)`, *optional*, returned when `future_values` is provided):
+ The loss of the TimesFM model.
+ """
+
+ mean_predictions: Optional[torch.Tensor] = None
+ full_predictions: Optional[torch.Tensor] = None
+ loss: Optional[Union[torch.Tensor, float]] = None
+
+
+class TimesFmMLP(nn.Module):
+ """Pax MLP in pytorch."""
+
+ def __init__(self, config: TimesFmConfig):
+ super().__init__()
+ hidden_size = config.hidden_size
+ intermediate_size = config.intermediate_size
+
+ self.gate_proj = nn.Linear(hidden_size, intermediate_size)
+ self.down_proj = nn.Linear(intermediate_size, hidden_size)
+ self.layer_norm = nn.LayerNorm(normalized_shape=hidden_size, eps=1e-6)
+
+ def forward(self, x, paddings=None):
+ gate_inp = self.layer_norm(x)
+ gate = self.gate_proj(gate_inp)
+ gate = F.relu(gate)
+ outputs = self.down_proj(gate)
+ if paddings is not None:
+ outputs = outputs * (1.0 - paddings[:, :, None])
+ return outputs + x
+
+
+class TimesFmResidualBlock(nn.Module):
+ """TimesFM residual block."""
+
+ def __init__(self, input_dims, hidden_dims, output_dims):
+ super().__init__()
+ self.input_dims = input_dims
+ self.hidden_dims = hidden_dims
+ self.output_dims = output_dims
+
+ self.input_layer = nn.Linear(input_dims, hidden_dims)
+ self.activation = nn.SiLU()
+ self.output_layer = nn.Linear(hidden_dims, output_dims)
+ self.residual_layer = nn.Linear(input_dims, output_dims)
+
+ def forward(self, x):
+ hidden = self.input_layer(x)
+ hidden = self.activation(hidden)
+ output = self.output_layer(hidden)
+ residual = self.residual_layer(x)
+ return output + residual
+
+
+@use_kernel_forward_from_hub("RMSNorm")
+class TimesFmRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ TimesFmRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+class TimesFmPositionalEmbedding(nn.Module):
+ """Generates position embedding for a given 1-d sequence."""
+
+ def __init__(self, config: TimesFmConfig):
+ super().__init__()
+ min_timescale = config.min_timescale
+ max_timescale = config.max_timescale
+ self.embedding_dims = config.hidden_size
+
+ num_timescales = self.embedding_dims // 2
+ log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / max(num_timescales - 1, 1)
+ self.register_buffer(
+ "inv_timescales",
+ min_timescale * torch.exp(torch.arange(num_timescales, dtype=torch.float32) * -log_timescale_increment),
+ )
+
+ def forward(self, seq_length=None, position=None):
+ """Generates a Tensor of sinusoids with different frequencies.
+
+ Args:
+ seq_length: an optional Python int defining the output sequence length.
+ Required if the `position` argument is not specified.
+ position: [B, seq_length], optional position for each token in the
+ sequence, only required when the sequence is packed.
+
+ Returns:
+ [B, seqlen, D] if `position` is specified, else [1, seqlen, D]
+ """
+ if position is None and seq_length is None:
+ raise ValueError("Either position or seq_length must be provided")
+
+ if position is None:
+ # [1, seqlen]
+ position = torch.arange(seq_length, dtype=torch.float32, device=self.inv_timescales.device).unsqueeze(0)
+ elif position.ndim != 2:
+ raise ValueError(f"position must be 2-dimensional, got shape {position.shape}")
+
+ scaled_time = position.view(*position.shape, 1) * self.inv_timescales.view(1, 1, -1)
+ signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)
+
+ # Padding to ensure correct embedding dimension
+ signal = F.pad(signal, (0, 0, 0, self.embedding_dims % 2))
+ return signal
+
+
+def simple_eager_attention_forward(
+ module: nn.Module,
+ query_states: torch.Tensor,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs,
+):
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+class TimesFmAttention(nn.Module):
+ """Implements the attention used in TimesFM. One key difference is that there is _per_dim_scaling of the query."""
+
+ def __init__(self, config: TimesFmConfig, layer_idx: int):
+ super().__init__()
+ self.config = config
+ self.is_causal = True
+ self.attention_dropout = config.attention_dropout
+ self.layer_idx = layer_idx
+
+ self.num_heads = config.num_attention_heads
+ self.hidden_size = config.hidden_size
+ self.head_dim = config.head_dim
+
+ self.q_size = self.num_heads * self.head_dim
+ self.kv_size = self.num_heads * self.head_dim
+ self.scaling = nn.Parameter(torch.empty((self.head_dim,)))
+
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim)
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size)
+
+ def _scale_query(self, query: torch.Tensor) -> torch.Tensor:
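+ # 1.442695041 = log2(e); applies a learned, softplus-parameterized per-dimension scale to the query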
+ scale = F.softplus(self.scaling).mul(1.442695041 / math.sqrt(self.head_dim))
+ return query * scale[None, None, None, :]
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ query_states = self._scale_query(query_states)
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ attention_interface: Callable = simple_eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
+ logger.warning_once(
+ "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
+ 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ else:
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=1.0,
+ **kwargs,
+ )
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class TimesFmDecoderLayer(nn.Module):
+ """Transformer layer."""
+
+ def __init__(self, config: TimesFmConfig, layer_idx: int):
+ super().__init__()
+
+ self.self_attn = TimesFmAttention(config, layer_idx=layer_idx)
+ self.mlp = TimesFmMLP(config)
+ self.input_layernorm = TimesFmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ paddings: torch.Tensor,
+ output_attentions: bool = False,
+ ) -> tuple[Optional[torch.Tensor], torch.Tensor]:
+ # Self Attention
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ hidden_states, scores = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ )
+ hidden_states = residual + hidden_states
+
+ # MLP
+ hidden_states = self.mlp(hidden_states, paddings=paddings)
+
+ return scores, hidden_states
+
+
+TIMESFM_START_DOCSTRING = r"""
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`TimesFmConfig`]):
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
+ load the weights associated with the model, only the configuration. Check out the
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+ "The bare TimesFM Model outputting raw hidden-states without any specific head on top.",
+ TIMESFM_START_DOCSTRING,
+)
+class TimesFmPreTrainedModel(PreTrainedModel):
+ """handles the loading for all models."""
+
+ config_class = TimesFmConfig
+ base_model_prefix = "timesfm"
+ _no_split_modules = ["TimesFmDecoderLayer"]
+ main_input_name = "past_values"
+ _supports_sdpa = True
+
+ def _init_weights(self, module):
+ if isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0, std=self.config.initializer_range)
+
+ elif isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0, std=self.config.initializer_range)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+
+ elif isinstance(module, nn.LayerNorm):
+ nn.init.ones_(module.weight)
+ nn.init.zeros_(module.bias)
+
+ elif isinstance(module, TimesFmRMSNorm):
+ nn.init.zeros_(module.weight)
+
+ elif isinstance(module, TimesFmAttention):
+ # Initialize scaling parameter
+ nn.init.ones_(module.scaling)
+
+
+TIMESFM_INPUTS_DOCSTRING = r"""
+ Args:
+ past_values: list of time series forecast contexts. Each context time series
+ can be a torch Tensor with a potentially different context length.
+ freq: frequency of each context time series in the inputs. 0 for high frequency
+ (default), 1 for medium, and 2 for low.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+"""
+
+
+@add_start_docstrings(
+ "The bare TimesFM Model outputting raw hidden-states without any specific head on top.",
+ TIMESFM_START_DOCSTRING,
+)
+class TimesFmModel(TimesFmPreTrainedModel):
+ """Patched time-series decoder without any specific output layer."""
+
+ def __init__(self, config: TimesFmConfig):
+ super().__init__(config)
+
+ self.config = config
+ self.input_ff_layer = TimesFmResidualBlock(
+ input_dims=2 * config.patch_length,
+ output_dims=config.hidden_size,
+ hidden_dims=config.intermediate_size,
+ )
+ self.freq_emb = nn.Embedding(num_embeddings=config.freq_size, embedding_dim=config.hidden_size)
+ self.layers = nn.ModuleList(
+ [TimesFmDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ if self.config.use_positional_embedding:
+ self.position_emb = TimesFmPositionalEmbedding(config=config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def _forward_transform(
+ self, inputs: torch.Tensor, patched_pads: torch.Tensor
+ ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+ """Input is of shape [B, N, P]."""
+ mu, sigma = self._timesfm_masked_mean_std(inputs, patched_pads)
+ sigma = torch.where(
+ sigma < self.config.tolerance,
+ torch.tensor(1.0, dtype=sigma.dtype, device=sigma.device),
+ sigma,
+ )
+
+ # Normalize each patch
+ outputs = (inputs - mu[:, None, None]) / sigma[:, None, None]
+ outputs = torch.where(
+ torch.abs(inputs - self.config.pad_val) < self.config.tolerance,
+ torch.tensor(self.config.pad_val, dtype=outputs.dtype, device=outputs.device),
+ outputs,
+ )
+ return outputs, (mu, sigma)
+
+ @can_return_tuple
+ @add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ past_values: torch.Tensor,
+ past_values_padding: torch.LongTensor,
+ freq: torch.Tensor,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ ) -> TimesFmOutput:
+ """
+ past_values_padding (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ The padding indicator of the time series.
+ """
+ # Reshape into patches (using view for efficiency)
+ bsize = past_values.shape[0]
+ patched_inputs = past_values.view(bsize, -1, self.config.patch_length)
+ patched_pads = past_values_padding.view(bsize, -1, self.config.patch_length)
+
+ patched_inputs = torch.where(
+ torch.abs(patched_pads - 1.0) < self.config.tolerance,
+ torch.tensor(0.0, dtype=patched_inputs.dtype, device=patched_inputs.device),
+ patched_inputs,
+ )
+ patched_pads = torch.where(
+ torch.abs(patched_inputs - self.config.pad_val) < self.config.tolerance,
+ torch.tensor(1.0, dtype=patched_pads.dtype, device=patched_pads.device),
+ patched_pads,
+ )
+ patched_inputs, stats = self._forward_transform(patched_inputs, patched_pads)
+
+ # B x N x D
+ patched_inputs = patched_inputs * (1.0 - patched_pads)
+ concat_inputs = torch.cat([patched_inputs, patched_pads], dim=-1)
+ model_input = self.input_ff_layer(concat_inputs)
+
+ # A patch counts as non-padded if at least one of its values is non-padded (hence the min).
+ patched_padding = torch.min(patched_pads, dim=-1)[0] # Get the values from the min result
+ if self.config.use_positional_embedding:
+ pos_emb = self.position_emb(model_input.shape[1])
+ pos_emb = torch.concat([pos_emb] * model_input.shape[0], dim=0)
+ pos_emb = self._timesfm_shift_padded_seq(patched_padding, pos_emb)
+ model_input += pos_emb
+
+ f_emb = self.freq_emb(freq) # B x 1 x D
+ model_input += f_emb
+
+ # Convert paddings to attention mask and combine with causal mask
+ hidden_states = model_input
+ attention_mask = self._prepare_4d_attention_mask(
+ attention_mask=patched_padding,
+ sequence_length=hidden_states.shape[1],
+ dtype=hidden_states.dtype,
+ device=hidden_states.device,
+ is_causal=True,
+ )
+
+ all_attentions = []
+ all_hidden_states = []
+
+ for layer in self.layers[: self.config.num_hidden_layers]:
+ scores, hidden_states = layer(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ paddings=patched_padding,
+ output_attentions=output_attentions,
+ )
+ if output_attentions:
+ all_attentions.append(scores)
+ if output_hidden_states:
+ all_hidden_states.append(hidden_states)
+
+ if output_hidden_states:
+ all_hidden_states = [model_input] + all_hidden_states
+ else:
+ all_hidden_states = None
+
+ return TimesFmOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_attentions if output_attentions else None,
+ loc=stats[0],
+ scale=stats[1],
+ )
+
+ @staticmethod
+ def _prepare_4d_attention_mask(
+ attention_mask: Optional[torch.Tensor],
+ sequence_length: int,
+ dtype: torch.dtype,
+ device: torch.device,
+ is_causal: bool = True,
+ ) -> Optional[torch.Tensor]:
+ """
+ Creates 4D attention mask and combines causal and padding masks if needed.
+
+ Args:
+ attention_mask: Optional tensor of shape (batch_size, seq_length) containing padding mask
+ sequence_length: Length of the sequence
+ dtype: Data type of the mask
+ device: Device of the mask
+ is_causal: Whether to apply causal masking
+
+ Returns:
+ 4D attention mask of shape (batch_size, 1, seq_length, seq_length)
+ """
+ # Get minimum value for the dtype
+ min_value = torch.finfo(dtype).min if dtype.is_floating_point else torch.iinfo(dtype).min
+
+ # Handle padding mask
+ if attention_mask is not None:
+ # Convert 2D padding mask to 4D attention mask
+ attention_mask = attention_mask.view(attention_mask.shape[0], 1, 1, -1)
+ attention_mask = attention_mask * min_value
+
+ # Create causal mask if needed
+ if is_causal:
+ causal_mask = torch.triu(
+ torch.ones((sequence_length, sequence_length), dtype=dtype, device=device) * min_value,
+ diagonal=1,
+ )
+ causal_mask = causal_mask.view(1, 1, sequence_length, sequence_length)
+
+ # Combine with padding mask if it exists
+ if attention_mask is not None:
+ attention_mask = torch.minimum(attention_mask, causal_mask)
+ else:
+ attention_mask = causal_mask
+
+ return attention_mask
+
+ @staticmethod
+ def _timesfm_masked_mean_std(inputs: torch.Tensor, padding: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ """Calculates mean and standard deviation of `inputs` across axis 1.
+
+ It excludes values where `padding` is 1.
+
+ Args:
+ inputs: A PyTorch tensor of shape [b, n, p].
+ padding: A PyTorch tensor of shape [b, n, p] with values 0 or 1.
+
+ Returns:
+ A tuple containing the mean and standard deviation.
+ We return the statistics of the first patch with at least three non-padded values.
+ """
+
+ # Selecting the first patch with more than 3 unpadded values.
+ def _get_patch_index(arr: torch.Tensor):
+ indices = torch.argmax((arr >= 3).to(torch.int32), dim=1)
+ row_sum = (arr >= 3).to(torch.int32).sum(dim=1)
+ return torch.where(row_sum == 0, arr.shape[1] - 1, indices)
+
+ pad_sum = torch.sum(1 - padding, dim=2)
+ patch_indices = _get_patch_index(pad_sum)
+ bidxs = torch.arange(inputs.shape[0])
+
+ arr = inputs[bidxs, patch_indices, :]
+ pad = padding[bidxs, patch_indices, :]
+
+ # Create a mask where padding is 0
+ mask = 1 - pad
+
+ # Calculate the number of valid elements
+ num_valid_elements = torch.sum(mask, dim=1)
+ num_valid_elements = torch.where(
+ num_valid_elements == 0,
+ torch.tensor(1, dtype=num_valid_elements.dtype, device=num_valid_elements.device),
+ num_valid_elements,
+ )
+
+ # Calculate the masked sum and squared sum
+ masked_sum = torch.sum(arr * mask, dim=1)
+ masked_squared_sum = torch.sum((arr * mask) ** 2, dim=1)
+
+ # Calculate the masked mean and standard deviation
+ masked_mean = masked_sum / num_valid_elements
+ masked_var = masked_squared_sum / num_valid_elements - masked_mean**2
+ masked_var = torch.where(
+ masked_var < 0.0,
+ torch.tensor(0.0, dtype=masked_var.dtype, device=masked_var.device),
+ masked_var,
+ )
+ masked_std = torch.sqrt(masked_var)
+
+ return masked_mean, masked_std
+
+ @staticmethod
+ def _timesfm_shift_padded_seq(mask: torch.Tensor, seq: torch.Tensor) -> torch.Tensor:
+ """Shifts rows of seq based on the first 0 in each row of the mask.
+
+ Args:
+ mask: mask tensor of shape [B, N]
+ seq: seq tensor of shape [B, N, P]
+
+ Returns:
+ The shifted sequence.
+ """
+ batch_size, num_seq, feature_dim = seq.shape
+
+ new_mask: torch.BoolTensor = mask == 0
+
+ # Use argmax to find the first True value in each row
+ indices = new_mask.to(torch.int32).argmax(dim=1)
+
+ # Handle rows with all zeros
+ indices[~new_mask.any(dim=1)] = -1
+
+ # Create index ranges for each sequence in the batch
+ idx_range = torch.arange(num_seq, device=seq.device).view(1, -1, 1).expand(batch_size, -1, feature_dim)
+
+ # Calculate shifted indices for each element in each sequence
+ shifted_idx = (idx_range - indices[:, None, None]) % num_seq
+
+ # Gather values from seq using shifted indices
+ shifted_seq = seq.gather(1, shifted_idx)
+
+ return shifted_seq
+
+
+class TimesFmModelForPrediction(TimesFmPreTrainedModel):
+ """TimesFM model for quantile and mean prediction."""
+
+ def __init__(self, config: TimesFmConfig):
+ super().__init__(config)
+
+ self.config = config
+ self.context_len = config.context_length
+ self.horizon_len = config.horizon_length
+
+ self.decoder = TimesFmModel(config)
+
+ # quantile and mean output
+ self.horizon_ff_layer = TimesFmResidualBlock(
+ input_dims=config.hidden_size,
+ output_dims=config.horizon_length * (1 + len(config.quantiles)),
+ hidden_dims=config.intermediate_size,
+ )
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def _preprocess(
+ self, inputs: Sequence[torch.Tensor], freq: Sequence[int]
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Formats and pads raw inputs to feed into the model.
+
+ This function both pads each time series to match the context length, and
+ pads the inputs to meet the SPMD shape requirement.
+
+ Args:
+ inputs: A list of 1d Tensors. Each Tensor is the context time series of
+ a single forecast task.
+ freq: list of frequencies
+
+ Returns:
+ A tuple of:
+ - the padded input time series to meet the model required context.
+ - the padding indicator.
+ - the number of padded examples for SPMD so that each core has the same
+ number (a multiple of `batch_size`) of examples.
+ """
+ input_ts, input_padding, inp_freq = [], [], []
+
+ for i, ts in enumerate(inputs):
+ input_len = ts.shape[0]
+ padding = torch.zeros(input_len + self.horizon_len, dtype=ts.dtype, device=ts.device)
+ if input_len < self.context_len:
+ num_front_pad = self.context_len - input_len
+ ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0)
+ padding = torch.cat([torch.ones(num_front_pad, dtype=ts.dtype, device=padding.device), padding], dim=0)
+ elif input_len > self.context_len:
+ ts = ts[-self.context_len :]
+ padding = padding[-(self.context_len + self.horizon_len) :]
+
+ input_ts.append(ts)
+ input_padding.append(padding)
+ inp_freq.append(freq[i])
+
+ return (
+ torch.stack(input_ts, dim=0),
+ torch.stack(input_padding, dim=0),
+ torch.tensor(inp_freq, dtype=torch.int32).reshape(-1, 1),
+ )
+
+ def _postprocess_output(
+ self, model_output: torch.Tensor, stats: tuple[torch.Tensor, torch.Tensor]
+ ) -> torch.Tensor:
+ """Postprocess output of stacked transformer."""
+
+ # B x N x (H.Q)
+ output_ts = self.horizon_ff_layer(model_output)
+
+ # Reshape using view
+ b, n, _ = output_ts.shape
+ output_ts = output_ts.view(b, n, self.config.horizon_length, len(self.config.quantiles) + 1)
+
+ mu, sigma = stats
+ return output_ts * sigma[:, None, None, None] + mu[:, None, None, None]
+
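+ # Pinball (quantile) loss: for quantile q and error e = target - prediction,
+ # each element contributes max((q - 1) * e, q * e); results are averaged over quantiles.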
+ def _quantile_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+ losses = []
+ for i, q in enumerate(self.config.quantiles):
+ errors = targets - predictions[..., i]
+ loss = torch.max((q - 1) * errors, q * errors)
+ losses.append(loss.mean())
+ return torch.stack(losses).mean()
+
+ @can_return_tuple
+ @add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=TimesFmOutputForPrediction, config_class=_CONFIG_FOR_DOC)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TimesFmOutputForPrediction,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def forward(
+ self,
+ past_values: Sequence[torch.Tensor],
+ freq: Optional[Sequence[Union[torch.Tensor, int]]] = None,
+ window_size: Optional[int] = None,
+ future_values: Optional[torch.Tensor] = None,
+ forecast_context_len: Optional[int] = None,
+ return_forecast_on_context: bool = False,
+ truncate_negative: bool = False,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ ) -> TimesFmOutputForPrediction:
+ r"""
+ window_size (`int`, *optional*):
+ Window size of trend + residual decomposition. If None then we do not do decomposition.
+ future_values (`torch.Tensor`, *optional*):
+ Optional future time series values to be used for loss computation.
+ forecast_context_len (`int`, *optional*):
+ Optional max context length.
+ return_forecast_on_context (`bool`, *optional*):
+ True to return the forecast on the context when available, i.e. after the first input patch.
+ truncate_negative (`bool`, *optional*):
+ Truncate to only non-negative values if any of the contexts have non-negative values,
+ otherwise do nothing.
+ output_attentions (`bool`, *optional*):
+ Whether to output the attentions.
+ output_hidden_states (`bool`, *optional*):
+ Whether to output the hidden states.
+
+ Returns:
+ A TimesFmOutputForPrediction object or a tuple containing:
+ - the mean forecast of shape (number of inputs, forecast horizon),
+ - the full forecast (mean + quantiles) of shape
+ (number of inputs, forecast horizon, 1 + number of quantiles),
+ - loss: the mean squared error loss + quantile loss if `future_values` is provided.
+ """
+ if forecast_context_len is None:
+ fcontext_len = self.context_len
+ else:
+ fcontext_len = forecast_context_len
+
+ # Get device from first input tensor
+ device = past_values[0].device
+
+ # Truncate inputs to forecast_context_len
+ inputs = [ts[-fcontext_len:] for ts in past_values]
+ inp_min = torch.min(torch.stack([torch.min(ts) for ts in inputs]))
+
+ if window_size is not None:
+ new_inputs = []
+ new_freqs = []
+ for i, ts in enumerate(inputs):
+ new_inputs.extend(self._timesfm_moving_average(ts, window_size))
+ if freq is not None:
+ new_freqs.extend([freq[i]] * 2)
+ inputs = new_inputs
+ if freq is not None:
+ freq = new_freqs
+
+ if freq is None:
+ logger.info("No frequency provided via `freq`. Default to high (0).")
+ freq = [0] * len(inputs)
+
+ if output_attentions is None:
+ output_attentions = self.config.output_attentions
+ if output_hidden_states is None:
+ output_hidden_states = self.config.output_hidden_states
+
+ input_ts, input_padding, inp_freq = self._preprocess(inputs, freq)
+ # Move tensors to the same device as input
+ input_ts = input_ts.to(device)
+ input_padding = input_padding.to(device)
+ inp_freq = inp_freq.to(device)
+
+ final_out = input_ts
+ context_len = final_out.shape[1]
+ full_outputs = []
+
+ if input_padding.shape[1] != final_out.shape[1] + self.horizon_len:
+ raise ValueError(
+ "Length of paddings must match length of input + horizon_len:"
+ f" {input_padding.shape[1]} != {final_out.shape[1]} + {self.horizon_len}"
+ )
+ output_patch_len = self.config.horizon_length
+
+ num_decode_patches = (self.horizon_len + output_patch_len - 1) // output_patch_len
+ for step_index in range(num_decode_patches):
+ current_padding = input_padding[:, 0 : final_out.shape[1]]
+ input_ts = final_out[:, -fcontext_len:]
+ input_padding = current_padding[:, -fcontext_len:]
+ decoder_output = self.decoder(
+ past_values=input_ts,
+ past_values_padding=input_padding,
+ freq=inp_freq,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+ fprop_outputs = self._postprocess_output(
+ decoder_output.last_hidden_state,
+ (decoder_output.loc, decoder_output.scale),
+ )
+
+ if return_forecast_on_context and step_index == 0:
+ # For the first decoding step, collect the model forecast on the
+ # context, except for the first input patch, whose forecast is unavailable.
+ new_full_ts = fprop_outputs[:, :-1, : self.config.patch_length, :]
+ # We have to use reshape and not view for non-contiguous memory
+ new_full_ts = new_full_ts.reshape(new_full_ts.size(0), -1, new_full_ts.size(3))
+
+ full_outputs.append(new_full_ts)
+
+ # (full batch, last patch, output_patch_len, index of mean forecast = 0)
+ new_ts = fprop_outputs[:, -1, :output_patch_len, 0]
+ # (full batch, last patch, output_patch_len, all output indices)
+ new_full_ts = fprop_outputs[:, -1, :output_patch_len, :]
+ full_outputs.append(new_full_ts)
+ final_out = torch.concatenate([final_out, new_ts], axis=-1)
+
+ if return_forecast_on_context:
+ # `full_outputs` indexing starts after the first input patch.
+ full_outputs = torch.concatenate(full_outputs, axis=1)[
+ :, : (context_len - self.config.patch_length + self.horizon_len), :
+ ]
+ else:
+ # `full_outputs` indexing starts at the forecast horizon.
+ full_outputs = torch.concatenate(full_outputs, axis=1)[:, 0 : self.horizon_len, :]
+
+ mean_outputs = full_outputs[:, :, 0]
+ if window_size is not None:
+ mean_outputs = mean_outputs[0::2, ...] + mean_outputs[1::2, ...]
+ full_outputs = full_outputs[0::2, ...] + full_outputs[1::2, ...]
+ if inp_min >= 0 and truncate_negative:
+ mean_outputs = torch.clamp(mean_outputs, min=0.0)
+ full_outputs = torch.clamp(full_outputs, min=0.0)
+
+ loss = None
+ if future_values is not None:
+ mse_loss = F.mse_loss(mean_outputs, future_values)
+ quantile_loss = self._quantile_loss(full_outputs[:, :, 1:], future_values)
+ loss = mse_loss + quantile_loss
+
+ return TimesFmOutputForPrediction(
+ last_hidden_state=decoder_output.last_hidden_state,
+ attentions=decoder_output.attentions if output_attentions else None,
+ hidden_states=decoder_output.hidden_states if output_hidden_states else None,
+ mean_predictions=mean_outputs,
+ full_predictions=full_outputs,
+ loss=loss,
+ )
+
+ @staticmethod
+ def _timesfm_moving_average(arr: torch.Tensor, window_size: int) -> list[torch.Tensor]:
+ """Calculates the moving average using PyTorch's convolution function."""
+ # Pad with zeros to handle initial window positions
+ arr_padded = F.pad(arr, (window_size - 1, 0), "constant", 0)
+ # Create a convolution kernel
+ kernel = torch.ones(window_size, dtype=arr.dtype, device=arr.device) / window_size
+ # Apply convolution to calculate the moving average
+ smoothed_arr = F.conv1d(arr_padded.view(1, 1, -1), kernel.view(1, 1, -1)).squeeze()
+ return [smoothed_arr, arr - smoothed_arr]
+
+
+__all__ = ["TimesFmModelForPrediction", "TimesFmPreTrainedModel", "TimesFmModel"]
diff --git a/src/transformers/models/timesfm/modular_timesfm.py b/src/transformers/models/timesfm/modular_timesfm.py
new file mode 100644
index 00000000000..4a627524849
--- /dev/null
+++ b/src/transformers/models/timesfm/modular_timesfm.py
@@ -0,0 +1,860 @@
+# coding=utf-8
+# Copyright 2025 Google LLC and HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch TimesFM model."""
+
+import math
+from dataclasses import dataclass
+from typing import Callable, Optional, Sequence, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_outputs import BaseModelOutput
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import (
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ can_return_tuple,
+ logging,
+ replace_return_docstrings,
+)
+from ..llama.modeling_llama import LlamaRMSNorm
+from ..phi4_multimodal.modeling_phi4_multimodal import simple_eager_attention_forward
+from .configuration_timesfm import TimesFmConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "google/timesfm-2.0-500m-pytorch"
+_CONFIG_FOR_DOC = "TimesFmConfig"
+
+
+@dataclass
+class TimesFmOutput(BaseModelOutput):
+ """
+ Args:
+ loc (`torch.Tensor` of shape `(batch_size,)`):
+ The mean of the time series inputs.
+ scale (`torch.Tensor` of shape `(batch_size,)`):
+ The scale of the time series inputs.
+ """
+
+ loc: Optional[torch.Tensor] = None
+ scale: Optional[torch.Tensor] = None
+
+
+@dataclass
+class TimesFmOutputForPrediction(BaseModelOutput):
+ """
+ Args:
+ mean_predictions (`torch.Tensor` of shape `(batch_size, sequence_length)`):
+ The mean predictions of the time series.
+ full_predictions (`torch.Tensor` of shape `(batch_size, sequence_length, 1 + len(quantiles))`):
+ The full predictions of the time series: the mean forecast at index 0 of the last dimension, followed by the quantiles.
+ loss (`torch.Tensor` of shape `(1,)`, *optional*, returned when `future_values` is provided):
+ The loss of the TimesFM model.
+ """
+
+ mean_predictions: Optional[torch.Tensor] = None
+ full_predictions: Optional[torch.Tensor] = None
+ loss: Optional[Union[torch.Tensor, float]] = None
+
+
+class TimesFmMLP(nn.Module):
+ """Pax MLP in pytorch."""
+
+ def __init__(self, config: TimesFmConfig):
+ super().__init__()
+ hidden_size = config.hidden_size
+ intermediate_size = config.intermediate_size
+
+ self.gate_proj = nn.Linear(hidden_size, intermediate_size)
+ self.down_proj = nn.Linear(intermediate_size, hidden_size)
+ self.layer_norm = nn.LayerNorm(normalized_shape=hidden_size, eps=1e-6)
+
+ def forward(self, x, paddings=None):
+ gate_inp = self.layer_norm(x)
+ gate = self.gate_proj(gate_inp)
+ gate = F.relu(gate)
+ outputs = self.down_proj(gate)
+ if paddings is not None:
+ outputs = outputs * (1.0 - paddings[:, :, None])
+ return outputs + x
+
+
+class TimesFmResidualBlock(nn.Module):
+ """TimesFM residual block."""
+
+ def __init__(self, input_dims, hidden_dims, output_dims):
+ super().__init__()
+ self.input_dims = input_dims
+ self.hidden_dims = hidden_dims
+ self.output_dims = output_dims
+
+ self.input_layer = nn.Linear(input_dims, hidden_dims)
+ self.activation = nn.SiLU()
+ self.output_layer = nn.Linear(hidden_dims, output_dims)
+ self.residual_layer = nn.Linear(input_dims, output_dims)
+
+ def forward(self, x):
+ hidden = self.input_layer(x)
+ hidden = self.activation(hidden)
+ output = self.output_layer(hidden)
+ residual = self.residual_layer(x)
+ return output + residual
+
+
+class TimesFmRMSNorm(LlamaRMSNorm):
+ pass
+
+
+class TimesFmPositionalEmbedding(nn.Module):
+ """Generates position embedding for a given 1-d sequence."""
+
+ def __init__(self, config: TimesFmConfig):
+ super().__init__()
+ min_timescale = config.min_timescale
+ max_timescale = config.max_timescale
+ self.embedding_dims = config.hidden_size
+
+ num_timescales = self.embedding_dims // 2
+ log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / max(num_timescales - 1, 1)
+ self.register_buffer(
+ "inv_timescales",
+ min_timescale * torch.exp(torch.arange(num_timescales, dtype=torch.float32) * -log_timescale_increment),
+ )
+
+ def forward(self, seq_length=None, position=None):
+ """Generates a Tensor of sinusoids with different frequencies.
+
+ Args:
+ seq_length: an optional Python int defining the output sequence length.
+ Required when the `position` argument is not specified.
+ position: [B, seq_length], optional position for each token in the
+ sequence, only required when the sequence is packed.
+
+ Returns:
+ [B, seqlen, D] if `position` is specified, else [1, seqlen, D]
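+
+ Example (illustrative sketch; `config` is any `TimesFmConfig`):
+
+ >>> pos_emb = TimesFmPositionalEmbedding(config)
+ >>> pos_emb(seq_length=16).shape # (1, 16, config.hidden_size)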
+ """
+ if position is None and seq_length is None:
+ raise ValueError("Either position or seq_length must be provided")
+
+ if position is None:
+ # [1, seqlen]
+ position = torch.arange(seq_length, dtype=torch.float32, device=self.inv_timescales.device).unsqueeze(0)
+ elif position.ndim != 2:
+ raise ValueError(f"position must be 2-dimensional, got shape {position.shape}")
+
+ scaled_time = position.view(*position.shape, 1) * self.inv_timescales.view(1, 1, -1)
+ signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)
+
+ # Padding to ensure correct embedding dimension
+ signal = F.pad(signal, (0, 0, 0, self.embedding_dims % 2))
+ return signal
+
+
+class TimesFmAttention(nn.Module):
+ """Implements the attention used in TimesFM. One key difference is that there is _per_dim_scaling of the query."""
+
+ def __init__(self, config: TimesFmConfig, layer_idx: int):
+ super().__init__()
+ self.config = config
+ self.is_causal = True
+ self.attention_dropout = config.attention_dropout
+ self.layer_idx = layer_idx
+
+ self.num_heads = config.num_attention_heads
+ self.hidden_size = config.hidden_size
+ self.head_dim = config.head_dim
+
+ self.q_size = self.num_heads * self.head_dim
+ self.kv_size = self.num_heads * self.head_dim
+ self.scaling = nn.Parameter(torch.empty((self.head_dim,)))
+
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim)
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size)
+
+ def _scale_query(self, query: torch.Tensor) -> torch.Tensor:
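+ # Per-dim query scaling: softplus keeps the learned scale positive, and
+ # 1.442695041 ~= 1 / softplus(0) = 1 / ln(2) normalizes it so the overall factor
+ # stays close to the usual 1/sqrt(head_dim). Since the query is pre-scaled here,
+ # `scaling=1.0` is passed to the attention function below.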
+ scale = F.softplus(self.scaling).mul(1.442695041 / math.sqrt(self.head_dim))
+ return query * scale[None, None, None, :]
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ query_states = self._scale_query(query_states)
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ attention_interface: Callable = simple_eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
+ logger.warning_once(
+ "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
+ 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ else:
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=1.0,
+ **kwargs,
+ )
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class TimesFmDecoderLayer(nn.Module):
+ """Transformer layer."""
+
+ def __init__(self, config: TimesFmConfig, layer_idx: int):
+ super().__init__()
+
+ self.self_attn = TimesFmAttention(config, layer_idx=layer_idx)
+ self.mlp = TimesFmMLP(config)
+ self.input_layernorm = TimesFmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ paddings: torch.Tensor,
+ output_attentions: bool = False,
+ ) -> tuple[Optional[torch.Tensor], torch.Tensor]:
+ # Self Attention
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ hidden_states, scores = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ )
+ hidden_states = residual + hidden_states
+
+ # MLP
+ hidden_states = self.mlp(hidden_states, paddings=paddings)
+
+ return scores, hidden_states
+
+
+TIMESFM_START_DOCSTRING = r"""
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`TimesFmConfig`]):
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
+ load the weights associated with the model, only the configuration. Check out the
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+ "The bare TimesFM Model outputting raw hidden-states without any specific head on top.",
+ TIMESFM_START_DOCSTRING,
+)
+class TimesFmPreTrainedModel(PreTrainedModel):
+ """handles the loading for all models."""
+
+ config_class = TimesFmConfig
+ base_model_prefix = "timesfm"
+ _no_split_modules = ["TimesFmDecoderLayer"]
+ main_input_name = "past_values"
+ _supports_sdpa = True
+
+ def _init_weights(self, module):
+ if isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0, std=self.config.initializer_range)
+
+ elif isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0, std=self.config.initializer_range)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+
+ elif isinstance(module, nn.LayerNorm):
+ nn.init.ones_(module.weight)
+ nn.init.zeros_(module.bias)
+
+ elif isinstance(module, TimesFmRMSNorm):
+ nn.init.zeros_(module.weight)
+
+ elif isinstance(module, TimesFmAttention):
+ # Initialize scaling parameter
+ nn.init.ones_(module.scaling)
+
+
+TIMESFM_INPUTS_DOCSTRING = r"""
+ Args:
+ past_values: list of time series forecast contexts. Each context is a 1D
+ torch Tensor; the contexts may have different lengths.
+ freq: frequency of each context time series in the inputs. 0 for high frequency
+ (default), 1 for medium, and 2 for low.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+"""
+
+
+@add_start_docstrings(
+ "The bare TimesFM Model outputting raw hidden-states without any specific head on top.",
+ TIMESFM_START_DOCSTRING,
+)
+class TimesFmModel(TimesFmPreTrainedModel):
+ """Patched time-series decoder without any specific output layer."""
+
+ def __init__(self, config: TimesFmConfig):
+ super().__init__(config)
+
+ self.config = config
+ self.input_ff_layer = TimesFmResidualBlock(
+ input_dims=2 * config.patch_length,
+ output_dims=config.hidden_size,
+ hidden_dims=config.intermediate_size,
+ )
+ self.freq_emb = nn.Embedding(num_embeddings=config.freq_size, embedding_dim=config.hidden_size)
+ self.layers = nn.ModuleList(
+ [TimesFmDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ if self.config.use_positional_embedding:
+ self.position_emb = TimesFmPositionalEmbedding(config=config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def _forward_transform(
+ self, inputs: torch.Tensor, patched_pads: torch.Tensor
+ ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+ """Input is of shape [B, N, P]."""
+ mu, sigma = self._timesfm_masked_mean_std(inputs, patched_pads)
+ sigma = torch.where(
+ sigma < self.config.tolerance,
+ torch.tensor(1.0, dtype=sigma.dtype, device=sigma.device),
+ sigma,
+ )
+
+ # Normalize each patch
+ outputs = (inputs - mu[:, None, None]) / sigma[:, None, None]
+ outputs = torch.where(
+ torch.abs(inputs - self.config.pad_val) < self.config.tolerance,
+ torch.tensor(self.config.pad_val, dtype=outputs.dtype, device=outputs.device),
+ outputs,
+ )
+ return outputs, (mu, sigma)
+
+ @can_return_tuple
+ @add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ past_values: torch.Tensor,
+ past_values_padding: torch.LongTensor,
+ freq: torch.Tensor,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ ) -> TimesFmOutput:
+ """
+ past_values_padding (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ The padding indicator of the time series.
+ """
+ # Reshape into patches (using view for efficiency)
+ bsize = past_values.shape[0]
+ patched_inputs = past_values.view(bsize, -1, self.config.patch_length)
+ patched_pads = past_values_padding.view(bsize, -1, self.config.patch_length)
+
+ patched_inputs = torch.where(
+ torch.abs(patched_pads - 1.0) < self.config.tolerance,
+ torch.tensor(0.0, dtype=patched_inputs.dtype, device=patched_inputs.device),
+ patched_inputs,
+ )
+ patched_pads = torch.where(
+ torch.abs(patched_inputs - self.config.pad_val) < self.config.tolerance,
+ torch.tensor(1.0, dtype=patched_pads.dtype, device=patched_pads.device),
+ patched_pads,
+ )
+ patched_inputs, stats = self._forward_transform(patched_inputs, patched_pads)
+
+ # B x N x D
+ patched_inputs = patched_inputs * (1.0 - patched_pads)
+ concat_inputs = torch.cat([patched_inputs, patched_pads], dim=-1)
+ model_input = self.input_ff_layer(concat_inputs)
+
+ # A patch counts as unpadded if it contains at least one unpadded value (min over the patch).
+ patched_padding = torch.min(patched_pads, dim=-1)[0] # Get the values from the min result
+ if self.config.use_positional_embedding:
+ pos_emb = self.position_emb(model_input.shape[1])
+ pos_emb = torch.concat([pos_emb] * model_input.shape[0], dim=0)
+ pos_emb = self._timesfm_shift_padded_seq(patched_padding, pos_emb)
+ model_input += pos_emb
+
+ f_emb = self.freq_emb(freq) # B x 1 x D
+ model_input += f_emb
+
+ # Convert paddings to attention mask and combine with causal mask
+ hidden_states = model_input
+ attention_mask = self._prepare_4d_attention_mask(
+ attention_mask=patched_padding,
+ sequence_length=hidden_states.shape[1],
+ dtype=hidden_states.dtype,
+ device=hidden_states.device,
+ is_causal=True,
+ )
+
+ all_attentions = []
+ all_hidden_states = []
+
+ for layer in self.layers[: self.config.num_hidden_layers]:
+ scores, hidden_states = layer(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ paddings=patched_padding,
+ output_attentions=output_attentions,
+ )
+ if output_attentions:
+ all_attentions.append(scores)
+ if output_hidden_states:
+ all_hidden_states.append(hidden_states)
+
+ if output_hidden_states:
+ all_hidden_states = [model_input] + all_hidden_states
+ else:
+ all_hidden_states = None
+
+ return TimesFmOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_attentions if output_attentions else None,
+ loc=stats[0],
+ scale=stats[1],
+ )
+
+ @staticmethod
+ def _prepare_4d_attention_mask(
+ attention_mask: Optional[torch.Tensor],
+ sequence_length: int,
+ dtype: torch.dtype,
+ device: torch.device,
+ is_causal: bool = True,
+ ) -> Optional[torch.Tensor]:
+ """
+ Creates 4D attention mask and combines causal and padding masks if needed.
+
+ Args:
+ attention_mask: Optional tensor of shape (batch_size, seq_length) containing padding mask
+ sequence_length: Length of the sequence
+ dtype: Data type of the mask
+ device: Device of the mask
+ is_causal: Whether to apply causal masking
+
+ Returns:
+ 4D attention mask of shape (batch_size, 1, seq_length, seq_length)
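+
+ Example (a causal-only sketch, with no padding mask):
+
+ >>> mask = TimesFmModel._prepare_4d_attention_mask(None, 3, torch.float32, torch.device("cpu"))
+ >>> mask.shape # (1, 1, 3, 3); future positions hold the dtype minimum, others are 0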
+ """
+ # Get minimum value for the dtype
+ min_value = torch.finfo(dtype).min if dtype.is_floating_point else torch.iinfo(dtype).min
+
+ # Handle padding mask
+ if attention_mask is not None:
+ # Convert 2D padding mask to 4D attention mask
+ attention_mask = attention_mask.view(attention_mask.shape[0], 1, 1, -1)
+ attention_mask = attention_mask * min_value
+
+ # Create causal mask if needed
+ if is_causal:
+ causal_mask = torch.triu(
+ torch.ones((sequence_length, sequence_length), dtype=dtype, device=device) * min_value,
+ diagonal=1,
+ )
+ causal_mask = causal_mask.view(1, 1, sequence_length, sequence_length)
+
+ # Combine with padding mask if it exists
+ if attention_mask is not None:
+ attention_mask = torch.minimum(attention_mask, causal_mask)
+ else:
+ attention_mask = causal_mask
+
+ return attention_mask
+
+ @staticmethod
+ def _timesfm_masked_mean_std(inputs: torch.Tensor, padding: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ """Calculates mean and standard deviation of `inputs` across axis 1.
+
+ It excludes values where `padding` is 1.
+
+ Args:
+ inputs: A PyTorch tensor of shape [b, n, p].
+ padding: A PyTorch tensor of shape [b, n, p] with values 0 or 1.
+
+ Returns:
+ A tuple containing the mean and standard deviation.
+ We return the statistics of the first patch with at least three unpadded values.
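+
+ Example (a minimal sketch; the last value of the only patch is padded):
+
+ >>> inputs = torch.tensor([[[1.0, 2.0, 3.0, 4.0]]])
+ >>> padding = torch.tensor([[[0.0, 0.0, 0.0, 1.0]]])
+ >>> mean, std = TimesFmModel._timesfm_masked_mean_std(inputs, padding)
+ >>> # mean is 2.0 and std is ~0.8165, computed over the three unpadded values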
+ """
+
+ # Select the first patch with at least 3 unpadded values.
+ def _get_patch_index(arr: torch.Tensor):
+ indices = torch.argmax((arr >= 3).to(torch.int32), dim=1)
+ row_sum = (arr >= 3).to(torch.int32).sum(dim=1)
+ return torch.where(row_sum == 0, arr.shape[1] - 1, indices)
+
+ pad_sum = torch.sum(1 - padding, dim=2)
+ patch_indices = _get_patch_index(pad_sum)
+ bidxs = torch.arange(inputs.shape[0])
+
+ arr = inputs[bidxs, patch_indices, :]
+ pad = padding[bidxs, patch_indices, :]
+
+ # Create a mask where padding is 0
+ mask = 1 - pad
+
+ # Calculate the number of valid elements
+ num_valid_elements = torch.sum(mask, dim=1)
+ num_valid_elements = torch.where(
+ num_valid_elements == 0,
+ torch.tensor(1, dtype=num_valid_elements.dtype, device=num_valid_elements.device),
+ num_valid_elements,
+ )
+
+ # Calculate the masked sum and squared sum
+ masked_sum = torch.sum(arr * mask, dim=1)
+ masked_squared_sum = torch.sum((arr * mask) ** 2, dim=1)
+
+ # Calculate the masked mean and standard deviation
+ masked_mean = masked_sum / num_valid_elements
+ masked_var = masked_squared_sum / num_valid_elements - masked_mean**2
+ masked_var = torch.where(
+ masked_var < 0.0,
+ torch.tensor(0.0, dtype=masked_var.dtype, device=masked_var.device),
+ masked_var,
+ )
+ masked_std = torch.sqrt(masked_var)
+
+ return masked_mean, masked_std
+
+ @staticmethod
+ def _timesfm_shift_padded_seq(mask: torch.Tensor, seq: torch.Tensor) -> torch.Tensor:
+ """Shifts rows of seq based on the first 0 in each row of the mask.
+
+ Args:
+ mask: mask tensor of shape [B, N]
+ seq: seq tensor of shape [B, N, P]
+
+ Returns:
+ The shifted sequence.
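+
+ Example (a minimal sketch; the first position of the row is padding):
+
+ >>> mask = torch.tensor([[1.0, 0.0, 0.0]])
+ >>> seq = torch.arange(3.0).view(1, 3, 1)
+ >>> TimesFmModel._timesfm_shift_padded_seq(mask, seq).flatten()
+ tensor([2., 0., 1.])
+ >>> # the position-0 row now sits at the first unpadded slot (index 1)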
+ """
+ batch_size, num_seq, feature_dim = seq.shape
+
+ new_mask: torch.BoolTensor = mask == 0
+
+ # Use argmax to find the first True value in each row
+ indices = new_mask.to(torch.int32).argmax(dim=1)
+
+ # Handle rows with all zeros
+ indices[~new_mask.any(dim=1)] = -1
+
+ # Create index ranges for each sequence in the batch
+ idx_range = torch.arange(num_seq, device=seq.device).view(1, -1, 1).expand(batch_size, -1, feature_dim)
+
+ # Calculate shifted indices for each element in each sequence
+ shifted_idx = (idx_range - indices[:, None, None]) % num_seq
+
+ # Gather values from seq using shifted indices
+ shifted_seq = seq.gather(1, shifted_idx)
+
+ return shifted_seq
+
+
+class TimesFmModelForPrediction(TimesFmPreTrainedModel):
+ """TimesFM model for quantile and mean prediction."""
+
+ def __init__(self, config: TimesFmConfig):
+ super().__init__(config)
+
+ self.config = config
+ self.context_len = config.context_length
+ self.horizon_len = config.horizon_length
+
+ self.decoder = TimesFmModel(config)
+
+ # quantile and mean output
+ self.horizon_ff_layer = TimesFmResidualBlock(
+ input_dims=config.hidden_size,
+ output_dims=config.horizon_length * (1 + len(config.quantiles)),
+ hidden_dims=config.intermediate_size,
+ )
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def _preprocess(
+ self, inputs: Sequence[torch.Tensor], freq: Sequence[int]
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Formats and pads raw inputs to feed into the model.
+
+ This function left-pads shorter series (and truncates longer ones) to the model
+ context length, and extends the padding indicator to also cover the forecast horizon.
+
+ Args:
+ inputs: A list of 1d Tensors. Each Tensor is the context time series of
+ a single forecast task.
+ freq: list of frequencies
+
+ Returns:
+ A tuple of:
+ - the input time series, padded or truncated to the model context length.
+ - the padding indicator, covering the context plus the forecast horizon.
+ - the frequency of each series as an int32 tensor of shape `(batch, 1)`.
+ """
+ input_ts, input_padding, inp_freq = [], [], []
+
+ for i, ts in enumerate(inputs):
+ input_len = ts.shape[0]
+ padding = torch.zeros(input_len + self.horizon_len, dtype=ts.dtype, device=ts.device)
+ if input_len < self.context_len:
+ num_front_pad = self.context_len - input_len
+ ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0)
+ padding = torch.cat([torch.ones(num_front_pad, dtype=ts.dtype, device=padding.device), padding], dim=0)
+ elif input_len > self.context_len:
+ ts = ts[-self.context_len :]
+ padding = padding[-(self.context_len + self.horizon_len) :]
+
+ input_ts.append(ts)
+ input_padding.append(padding)
+ inp_freq.append(freq[i])
+
+ return (
+ torch.stack(input_ts, dim=0),
+ torch.stack(input_padding, dim=0),
+ torch.tensor(inp_freq, dtype=torch.int32).reshape(-1, 1),
+ )
+
+ def _postprocess_output(
+ self, model_output: torch.Tensor, stats: tuple[torch.Tensor, torch.Tensor]
+ ) -> torch.Tensor:
+ """Postprocess output of stacked transformer."""
+
+ # [batch, num_patches, horizon_length * (1 + num_quantiles)]
+ output_ts = self.horizon_ff_layer(model_output)
+
+ # Reshape using view
+ b, n, _ = output_ts.shape
+ output_ts = output_ts.view(b, n, self.config.horizon_length, len(self.config.quantiles) + 1)
+
+ mu, sigma = stats
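+ # Undo the per-series normalization applied before the decoder stack: x * sigma + mu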
+ return output_ts * sigma[:, None, None, None] + mu[:, None, None, None]
+
+ def _quantile_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+ """Mean pinball (quantile) loss over the configured quantiles: for each
+ quantile `q`, the per-element loss is `max(q * error, (q - 1) * error)`
+ with `error = target - prediction`."""
+ losses = []
+ for i, q in enumerate(self.config.quantiles):
+ errors = targets - predictions[..., i]
+ loss = torch.max((q - 1) * errors, q * errors)
+ losses.append(loss.mean())
+ return torch.stack(losses).mean()
+
+ @can_return_tuple
+ @add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=TimesFmOutputForPrediction, config_class=_CONFIG_FOR_DOC)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TimesFmOutputForPrediction,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def forward(
+ self,
+ past_values: Sequence[torch.Tensor],
+ freq: Optional[Sequence[Union[torch.Tensor, int]]] = None,
+ window_size: Optional[int] = None,
+ future_values: Optional[torch.Tensor] = None,
+ forecast_context_len: Optional[int] = None,
+ return_forecast_on_context: bool = False,
+ truncate_negative: bool = False,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ ) -> TimesFmOutputForPrediction:
+ r"""
+ window_size (`int`, *optional*):
+ Window size for the trend + residual decomposition. If None, no decomposition is applied.
+ future_values (`torch.Tensor`, *optional*):
+ Optional future time series values to be used for loss computation.
+ forecast_context_len (`int`, *optional*):
+ Optional max context length.
+ return_forecast_on_context (`bool`, *optional*):
+ True to return the forecast on the context when available, i.e. after the first input patch.
+ truncate_negative (`bool`, *optional*):
+ Truncate the forecast to non-negative values if all of the context values are
+ non-negative, otherwise do nothing.
+ output_attentions (`bool`, *optional*):
+ Whether to output the attentions.
+ output_hidden_states (`bool`, *optional*):
+ Whether to output the hidden states.
+
+ Returns:
+ A `TimesFmOutputForPrediction` object, or a tuple, containing:
+ - the mean forecast of shape (# past_values, # forecast horizon),
+ - the full forecast (mean + quantiles) of shape
+ (# past_values, # forecast horizon, 1 + # quantiles),
+ - the mean squared error loss plus the quantile loss, when `future_values` is provided.
+ """
+ if forecast_context_len is None:
+ fcontext_len = self.context_len
+ else:
+ fcontext_len = forecast_context_len
+
+ # Get device from first input tensor
+ device = past_values[0].device
+
+ # Truncate inputs to forecast_context_len
+ inputs = [ts[-fcontext_len:] for ts in past_values]
+ inp_min = torch.min(torch.stack([torch.min(ts) for ts in inputs]))
+
+ if window_size is not None:
+ new_inputs = []
+ new_freqs = []
+ for i, ts in enumerate(inputs):
+ new_inputs.extend(self._timesfm_moving_average(ts, window_size))
+ if freq is not None:
+ new_freqs.extend([freq[i]] * 2)
+ inputs = new_inputs
+ if freq is not None:
+ freq = new_freqs
+
+ if freq is None:
+ logger.info("No frequency provided via `freq`. Default to high (0).")
+ freq = [0] * len(inputs)
+
+ if output_attentions is None:
+ output_attentions = self.config.output_attentions
+ if output_hidden_states is None:
+ output_hidden_states = self.config.output_hidden_states
+
+ input_ts, input_padding, inp_freq = self._preprocess(inputs, freq)
+ # Move tensors to the same device as input
+ input_ts = input_ts.to(device)
+ input_padding = input_padding.to(device)
+ inp_freq = inp_freq.to(device)
+
+ final_out = input_ts
+ context_len = final_out.shape[1]
+ full_outputs = []
+
+ if input_padding.shape[1] != final_out.shape[1] + self.horizon_len:
+ raise ValueError(
+ "Length of paddings must match length of input + horizon_len:"
+ f" {input_padding.shape[1]} != {final_out.shape[1]} + {self.horizon_len}"
+ )
+ output_patch_len = self.config.horizon_length
+
+ num_decode_patches = (self.horizon_len + output_patch_len - 1) // output_patch_len
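+ # Decode autoregressively: each step feeds the most recent `fcontext_len` values
+ # (the context plus any forecasts generated so far) through the decoder and
+ # appends one more patch of `output_patch_len` predictions to `final_out`.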
+ for step_index in range(num_decode_patches):
+ current_padding = input_padding[:, 0 : final_out.shape[1]]
+ # Use fresh local names so `input_padding` keeps its full (context + horizon)
+ # length across decode steps instead of being overwritten by its truncation.
+ context_ts = final_out[:, -fcontext_len:]
+ context_padding = current_padding[:, -fcontext_len:]
+ decoder_output = self.decoder(
+ past_values=context_ts,
+ past_values_padding=context_padding,
+ freq=inp_freq,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+ fprop_outputs = self._postprocess_output(
+ decoder_output.last_hidden_state,
+ (decoder_output.loc, decoder_output.scale),
+ )
+
+ if return_forecast_on_context and step_index == 0:
+ # For the first decoding step, collect the model forecast on the
+ # context, except the first input patch whose forecast is unavailable.
+ new_full_ts = fprop_outputs[:, :-1, : self.config.patch_length, :]
+ # We have to use reshape and not view for non-contiguous memory
+ new_full_ts = new_full_ts.reshape(new_full_ts.size(0), -1, new_full_ts.size(3))
+
+ full_outputs.append(new_full_ts)
+
+ # (batch, output_patch_len): the mean forecast (last-dim index 0) of the last patch
+ new_ts = fprop_outputs[:, -1, :output_patch_len, 0]
+ # (batch, output_patch_len, 1 + #quantiles): all outputs of the last patch
+ new_full_ts = fprop_outputs[:, -1, :output_patch_len, :]
+ full_outputs.append(new_full_ts)
+ final_out = torch.concatenate([final_out, new_ts], axis=-1)
+
+ if return_forecast_on_context:
+ # `full_outputs` indexing starts right after the first input patch.
+ full_outputs = torch.concatenate(full_outputs, axis=1)[
+ :, : (context_len - self.config.patch_length + self.horizon_len), :
+ ]
+ else:
+ # `full_outputs` indexing starts at the beginning of the forecast horizon.
+ full_outputs = torch.concatenate(full_outputs, axis=1)[:, 0 : self.horizon_len, :]
+
+ mean_outputs = full_outputs[:, :, 0]
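+ # With decomposition, consecutive batch rows hold the trend and residual
+ # forecasts of the same series; summing row pairs recombines them.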
+ if window_size is not None:
+ mean_outputs = mean_outputs[0::2, ...] + mean_outputs[1::2, ...]
+ full_outputs = full_outputs[0::2, ...] + full_outputs[1::2, ...]
+ if inp_min >= 0 and truncate_negative:
+ # `torch.maximum` does not accept a Python scalar, so clamp at zero instead
+ mean_outputs = torch.clamp(mean_outputs, min=0.0)
+ full_outputs = torch.clamp(full_outputs, min=0.0)
+
+ loss = None
+ if future_values is not None:
+ mse_loss = F.mse_loss(mean_outputs, future_values)
+ quantile_loss = self._quantile_loss(full_outputs[:, :, 1:], future_values)
+ loss = mse_loss + quantile_loss
+
+ return TimesFmOutputForPrediction(
+ last_hidden_state=decoder_output.last_hidden_state,
+ attentions=decoder_output.attentions if output_attentions else None,
+ hidden_states=decoder_output.hidden_states if output_hidden_states else None,
+ mean_predictions=mean_outputs,
+ full_predictions=full_outputs,
+ loss=loss,
+ )
+
+ @staticmethod
+ def _timesfm_moving_average(arr: torch.Tensor, window_size: int) -> list[torch.Tensor]:
+ """Calculates the moving average using PyTorch's convolution function."""
+ # Pad with zeros to handle initial window positions
+ arr_padded = F.pad(arr, (window_size - 1, 0), "constant", 0)
+ # Create a convolution kernel
+ kernel = torch.ones(window_size, dtype=arr.dtype, device=arr.device) / window_size
+ # Apply convolution to calculate the moving average
+ smoothed_arr = F.conv1d(arr_padded.view(1, 1, -1), kernel.view(1, 1, -1)).squeeze()
+ return [smoothed_arr, arr - smoothed_arr]
+
+
+__all__ = ["TimesFmModelForPrediction", "TimesFmPreTrainedModel", "TimesFmModel"]
diff --git a/tests/models/timesfm/__init__.py b/tests/models/timesfm/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/tests/models/timesfm/test_modeling_timesfm.py b/tests/models/timesfm/test_modeling_timesfm.py
new file mode 100644
index 00000000000..6d69a97d352
--- /dev/null
+++ b/tests/models/timesfm/test_modeling_timesfm.py
@@ -0,0 +1,197 @@
+# coding=utf-8
+# Copyright 2025 Google LLC and HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import unittest
+from typing import List
+
+import numpy as np
+import torch
+
+from transformers import TimesFmConfig, is_torch_available
+from transformers.testing_utils import require_torch, slow, torch_device
+
+ from ...test_configuration_common import ConfigTester
+ from ...test_modeling_common import ModelTesterMixin
+
+
+if is_torch_available():
+ from transformers import TimesFmModelForPrediction
+
+TOLERANCE = 1e-4
+
+
+class TimesFmModelTester:
+ def __init__(
+ self,
+ parent,
+ patch_length: int = 32,
+ context_length: int = 512,
+ horizon_length: int = 128,
+ freq_size: int = 3,
+ num_hidden_layers: int = 1,
+ hidden_size: int = 16,
+ intermediate_size: int = 32,
+ head_dim: int = 8,
+ num_heads: int = 2,
+ tolerance: float = 1e-6,
+ rms_norm_eps: float = 1e-6,
+ quantiles: List[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
+ pad_val: float = 1123581321.0,
+ use_positional_embedding: bool = True,
+ initializer_factor: float = 0.0,
+ is_training: bool = False,
+ batch_size: int = 3,
+ ):
+ self.parent = parent
+ self.patch_length = patch_length
+ self.context_length = context_length
+ self.horizon_length = horizon_length
+ self.quantiles = quantiles
+ self.pad_val = pad_val
+ self.freq_size = freq_size
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.head_dim = head_dim
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_heads
+ self.tolerance = tolerance
+ self.rms_norm_eps = rms_norm_eps
+ self.use_positional_embedding = use_positional_embedding
+ self.initializer_factor = initializer_factor
+ self.is_training = is_training
+ self.batch_size = batch_size
+
+ # The size of test input
+ self.seq_length = context_length // patch_length
+
+ def get_config(self):
+ return TimesFmConfig(
+ patch_length=self.patch_length,
+ context_length=self.context_length,
+ horizon_length=self.horizon_length,
+ quantiles=self.quantiles,
+ pad_val=self.pad_val,
+ freq_size=self.freq_size,
+ hidden_size=self.hidden_size,
+ intermediate_size=self.intermediate_size,
+ head_dim=self.head_dim,
+ num_hidden_layers=self.num_hidden_layers,
+ num_attention_heads=self.num_attention_heads,
+ tolerance=self.tolerance,
+ rms_norm_eps=self.rms_norm_eps,
+ use_positional_embedding=self.use_positional_embedding,
+ initializer_factor=self.initializer_factor,
+ )
+
+ def get_pipeline_config(self):
+ return self.get_config()
+
+ def prepare_config_and_inputs(self):
+ forecast_input = [
+ torch.tensor(np.sin(np.linspace(0, 20, 100)), dtype=torch.float32, device=torch_device),
+ torch.tensor(np.cos(np.linspace(0, 20, 100)), dtype=torch.float32, device=torch_device),
+ torch.tensor(np.tan(np.linspace(0, 20, 100)), dtype=torch.float32, device=torch_device),
+ ]
+ frequency_input = torch.tensor([0, 1, 2], dtype=torch.long, device=torch_device)
+
+ return (self.get_config(), torch.stack(forecast_input, dim=0), frequency_input)
+
+ def prepare_config_and_inputs_for_common(self):
+ (config, forecast_input, frequency_input) = self.prepare_config_and_inputs()
+
+ inputs_dict = {
+ "past_values": forecast_input,
+ "freq": frequency_input,
+ }
+ return config, inputs_dict
+
+
+@require_torch
+class TimesFmModelTest(ModelTesterMixin, unittest.TestCase):
+ all_model_classes = (TimesFmModelForPrediction,) if is_torch_available() else ()
+ all_generative_model_classes = ()
+ all_parallelizable_model_classes = ()
+ fx_compatible = False
+ test_pruning = False
+ test_resize_embeddings = False
+ test_model_parallel = False
+ is_encoder_decoder = False
+ test_inputs_embeds = False
+
+ def setUp(self):
+ self.model_tester = TimesFmModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=TimesFmConfig)
+
+ def test_create_and_run_model(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = TimesFmModelForPrediction(config)
+ model.to(torch_device)
+ model.eval()
+ results = model(**inputs_dict)
+ assert results.mean_predictions is not None
+
+ @unittest.skip(reason="Compile not yet supported because of masks")
+ def test_sdpa_can_dispatch_on_flash(self):
+ pass
+
+ @unittest.skip(reason="Model does not have input embeddings")
+ def test_model_get_set_embeddings(self):
+ pass
+
+ @unittest.skip(reason="Model does not have head mask")
+ def test_headmasking(self):
+ pass
+
+ # the main input name is `past_values`
+ def test_model_main_input_name(self):
+ model_signature = inspect.signature(getattr(TimesFmModelForPrediction, "forward"))
+ # The main input is the name of the argument after `self`
+ observed_main_input_name = list(model_signature.parameters.keys())[1]
+ self.assertEqual(TimesFmModelForPrediction.main_input_name, observed_main_input_name)
+
+
+@require_torch
+@slow
+class TimesFmModelIntegrationTests(unittest.TestCase):
+ def test_inference_no_head(self):
+ model = TimesFmModelForPrediction.from_pretrained("google/timesfm-2.0-500m-pytorch", revision="refs/pr/7").to(
+ torch_device
+ )
+ forecast_input = [
+ np.sin(np.linspace(0, 20, 100)),
+ np.sin(np.linspace(0, 20, 200)),
+ np.sin(np.linspace(0, 20, 400)),
+ ]
+ forecast_input_tensor = [torch.tensor(ts, dtype=torch.float32, device=torch_device) for ts in forecast_input]
+ frequency_input = [0, 1, 2]
+
+ with torch.no_grad():
+ output = model(past_values=forecast_input_tensor, freq=frequency_input).last_hidden_state
+
+ self.assertEqual(
+ output.shape,
+ torch.Size([3, model.config.context_length // model.config.patch_length, model.config.hidden_size]),
+ )
+ expected_slice = torch.tensor(
+ [[-0.4267, -0.7273, -0.3932], [-0.4267, -0.7273, -0.3932], [-0.4267, -0.7273, -0.3932]],
+ device=torch_device,
+ )
+ self.assertTrue(torch.allclose(output[0, :3, :3], expected_slice, atol=TOLERANCE))
diff --git a/utils/check_repo.py b/utils/check_repo.py
index aeba4ee73de..85178b663e4 100644
--- a/utils/check_repo.py
+++ b/utils/check_repo.py
@@ -156,6 +156,7 @@ IGNORE_NON_TESTED = (
"Llama4VisionModel", # Building part of bigger (tested) model. # TODO: add tests
"Emu3VQVAE", # Building part of bigger (tested) model
"Emu3TextModel", # Building part of bigger (tested) model
+ "TimesFmModel", # Building part of bigger (tested) model
]
)