From a91020aed0b15794d0842e5799ec9d360e939f4e Mon Sep 17 00:00:00 2001
From: Jinan Zhou
Date: Wed, 16 Apr 2025 06:00:53 -0700
Subject: [PATCH] Add TimesFM Time Series Forecasting Model (#34082)

* initial documentation
* rename mask to attention_mask
* smaller tests
* fixup
* fix copies
* move to time series section
* sort docs
* isort fix
* batch_size is not a configuration
* rename to TimesFMModelForPrediction
* initial script
* add check_outputs
* remove dropout_rate
* works with torch.Tensor inputs
* rename script
* fix docstrings
* fix freq when window_size is given
* add loss
* fix _quantile_loss
* formatting
* fix isort
* add weight init
* add support for sdpa and flash_attention_2
* fixes for flash_attention
* formatting
* remove flash_attention
* fix tests
* fix file name
* fix quantile loss
* added initial TimesFMModelIntegrationTests
* fix formatting
* fix import order
* fix _quantile_loss
* add doc for SDPA
* use timesfm 2.0
* bug fix in timesfm decode function.
* compare mean forecasts
* refactor type hints, use CamelCase
* consolidate decode func
* more readable code for weight conversion
* fix-copies
* simpler init
* renaem TimesFmMLP
* use T5LayerNorm
* fix tests
* use initializer_range
* TimesFmModel instead of TimesFmDecoder
* TimesFmPositionalEmbedding takes config for its init
* 2.0-500m-pytorch default configs
* use TimesFmModel
* fix formatting
* ignore TimesFmModel for testing
* fix docstring
* override generate as its not needed
* add doc strings
* fix logging
* add docstrings to output data classes
* initial copy from t5
* added config and attention layers
* add TimesFMPositionalEmbedding
* calcuate scale_factor once
* add more configs and TimesFMResidualBlock
* fix input_dims
* standardize code format with black
* remove unneeded modules
* TimesFM Model
* order of imports
* copy from Google official implementation
* remove covariate forecasting
* Adapting TimesFM to HF format
* restructing in progress
* adapted to HF convention
* timesfm test
* the model runs
* fixing unit tests
* fixing unit tests in progress
* add post_init
* do not change TimesFMOutput
* fixing unit tests
* all unit tests passed
* remove timesfm_layers
* add intermediate_size and initialize with config
* add _CHECKPOINT_FOR_DOC
* fix comments
* Revert "fix comments"

This reverts commit 8deeb3e191b3671bc1d74dbfe77b736a066c3d34.

* add _prepare_4d_attention_mask
* we do not have generative model classes
* use Cache
* return past_key_values
* modules initialized with config only
* update year
* Update docs/source/en/model_doc/timesfm.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* add layer_idx to cache
* modular timesfm
* fix test
* unwrap sequential class
* fix toctree
* remove TimesFmOnnxConfig
* fix modular
* remove TimesFmStackedDecoder
* split qkv layer into individual layers
* rename projection layers
* use ALL_ATTENTION_FUNCTIONS
* is_causal is True
* rename config
* does not support flash_attn_2
* formatting
* fix typo in docsstring
* rename inputs
* add time series mapping
* Update src/transformers/models/olmo2/modeling_olmo2.py
* Update src/transformers/models/moonshine/modeling_moonshine.py
* use updated arguments
* fix class name
* add MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING
* isort
* consolidate _preprocess into forward
* fix a typo
* fix a typo
* fix toc
* fix modular
* remove aaserts
* use self.config._attn_implementation
* move to _postprocess_output
* remove timesfm_get_large_negative_number
* use view unstead of multiple unsqueeze
* make helpers static methods of the Model
* use to_tuple
* use to_tuple if not return_dict
* remove unused intitialization block as its incorporated in nn.Linear
* remove unused num_key_value_groups
* use the same convention as the masking method
* update modular
* do not use unsqueeze
* use view instead of unsqueeze
* use buffer for inv_timescales
* formatting
* modular conversion
* remove unneeded intialization
* add missing docstrings
* remove cache
* use simple_eager_attention_forward
* support tp_plan
* support for flex and flash attention masks
* Revert "support for flex and flash attention masks"

This reverts commit def36c4fcf31599b3f4937c9334b7da1a20132c3.
* fix device * fix tests on gpu * remove unsued large model test * removed unneeded comments * add example usage * fix style * add import * Update docs/source/en/model_doc/timesfm.md Co-authored-by: Cyril Vallez * inherit from LlamaRMSNorm * use can_return_tuple decorator * remvoe return_dict * fix year * Update docs/source/en/model_doc/timesfm.md Co-authored-by: Cyril Vallez * pretrained does not inherit from GenerationMixin * use model for integration test --------- Co-authored-by: Kashif Rasul Co-authored-by: Rajat Sen Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> Co-authored-by: Cyril Vallez Co-authored-by: Cyril Vallez --- docs/source/en/_toctree.yml | 2 + docs/source/en/model_doc/timesfm.md | 88 ++ src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 2 + src/transformers/models/auto/modeling_auto.py | 21 + src/transformers/models/timesfm/__init__.py | 27 + .../models/timesfm/configuration_timesfm.py | 129 +++ .../timesfm/convert_timesfm_orignal_to_hf.py | 275 ++++++ .../models/timesfm/modeling_timesfm.py | 904 ++++++++++++++++++ .../models/timesfm/modular_timesfm.py | 860 +++++++++++++++++ tests/models/timesfm/__init__.py | 0 tests/models/timesfm/test_modeling_timesfm.py | 197 ++++ utils/check_repo.py | 1 + 13 files changed, 2507 insertions(+) create mode 100644 docs/source/en/model_doc/timesfm.md create mode 100644 src/transformers/models/timesfm/__init__.py create mode 100644 src/transformers/models/timesfm/configuration_timesfm.py create mode 100644 src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py create mode 100644 src/transformers/models/timesfm/modeling_timesfm.py create mode 100644 src/transformers/models/timesfm/modular_timesfm.py create mode 100644 tests/models/timesfm/__init__.py create mode 100644 tests/models/timesfm/test_modeling_timesfm.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 91522f0fc30..0bd54e5a3b1 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -1059,6 +1059,8 @@ title: PatchTST - local: model_doc/time_series_transformer title: Time Series Transformer + - local: model_doc/timesfm + title: TimesFM title: Time series models - sections: - local: model_doc/graphormer diff --git a/docs/source/en/model_doc/timesfm.md b/docs/source/en/model_doc/timesfm.md new file mode 100644 index 00000000000..f5e27994919 --- /dev/null +++ b/docs/source/en/model_doc/timesfm.md @@ -0,0 +1,88 @@ + + +# TimesFM + +
+PyTorch +
+
+## Overview
+
+TimesFM (Time Series Foundation Model) is a pretrained time-series foundation model proposed in [A decoder-only foundation model for time-series forecasting](https://huggingface.co/papers/2310.10688) by Abhimanyu Das, Weihao Kong, Rajat Sen, and Yichen Zhou. It is a decoder-only model that takes non-overlapping patches of time-series data as input and autoregressively predicts an output patch of configurable length.
+
+The abstract from the paper is the following:
+
+*Motivated by recent advances in large language models for Natural Language Processing (NLP), we design a time-series foundation model for forecasting whose out-of-the-box zero-shot performance on a variety of public datasets comes close to the accuracy of state-of-the-art supervised forecasting models for each individual dataset. Our model is based on pretraining a patched-decoder style attention model on a large time-series corpus, and can work well across different forecasting history lengths, prediction lengths and temporal granularities.*
+
+This model was contributed by [kashif](https://huggingface.co/kashif).
+The original code can be found [here](https://github.com/google-research/timesfm).
+
+To use the model:
+
+```python
+import numpy as np
+import torch
+from transformers import TimesFmModelForPrediction
+
+
+model = TimesFmModelForPrediction.from_pretrained(
+    "google/timesfm-2.0-500m-pytorch",
+    torch_dtype=torch.bfloat16,
+    attn_implementation="sdpa",
+    device_map="cuda" if torch.cuda.is_available() else None
+)
+
+# Create dummy inputs
+forecast_input = [
+    np.sin(np.linspace(0, 20, 100)),
+    np.sin(np.linspace(0, 20, 200)),
+    np.sin(np.linspace(0, 20, 400)),
+]
+frequency_input = [0, 1, 2]
+
+# Convert inputs to sequence of tensors
+forecast_input_tensor = [
+    torch.tensor(ts, dtype=torch.bfloat16).to("cuda" if torch.cuda.is_available() else "cpu")
+    for ts in forecast_input
+]
+frequency_input_tensor = torch.tensor(frequency_input, dtype=torch.long).to(
+    "cuda" if torch.cuda.is_available() else "cpu"
+)
+
+# Get predictions from the pre-trained model
+with torch.no_grad():
+    outputs = model(past_values=forecast_input_tensor, freq=frequency_input_tensor, return_dict=True)
+    point_forecast = outputs.mean_predictions.float().cpu().numpy()
+    quantile_forecast = outputs.full_predictions.float().cpu().numpy()
+```
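+
+The `forward` of [`TimesFmModelForPrediction`] also accepts `future_values` for loss computation. A minimal sketch reusing the tensors above; the held-out continuation here is synthetic and purely illustrative:
+
+```python
+# Hypothetical ground-truth continuations, shaped (batch, horizon_length)
+future_values = torch.stack(
+    [torch.tensor(np.sin(np.linspace(20, 28, 128)), dtype=torch.bfloat16) for _ in range(3)]
+).to(model.device)
+
+with torch.no_grad():
+    outputs = model(
+        past_values=forecast_input_tensor,
+        freq=frequency_input_tensor,
+        future_values=future_values,
+        return_dict=True,
+    )
+
+# MSE on the mean forecast plus the quantile (pinball) loss on the quantile heads
+print(outputs.loss)
+```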
"TimmWrapperConfig"), @@ -681,6 +682,7 @@ MODEL_NAMES_MAPPING = OrderedDict( ("tapex", "TAPEX"), ("textnet", "TextNet"), ("time_series_transformer", "Time Series Transformer"), + ("timesfm", "TimesFm"), ("timesformer", "TimeSformer"), ("timm_backbone", "TimmBackbone"), ("timm_wrapper", "TimmWrapperModel"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index e630333602c..af832ee2393 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -281,6 +281,7 @@ MODEL_MAPPING_NAMES = OrderedDict( ("tapas", "TapasModel"), ("textnet", "TextNetModel"), ("time_series_transformer", "TimeSeriesTransformerModel"), + ("timesfm", "TimesFmModel"), ("timesformer", "TimesformerModel"), ("timm_backbone", "TimmBackbone"), ("timm_wrapper", "TimmWrapperModel"), @@ -1542,6 +1543,12 @@ MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING_NAMES = OrderedDict( ] ) +MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING_NAMES = OrderedDict( + [ + ("timesfm", "TimesFmModelForPrediction"), + ] +) + MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES = OrderedDict( [ ("swin2sr", "Swin2SRForImageSuperResolution"), @@ -1650,6 +1657,10 @@ MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING = _LazyAutoMapping( CONFIG_MAPPING_NAMES, MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING_NAMES ) +MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING_NAMES +) + MODEL_FOR_IMAGE_TO_IMAGE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES) @@ -1820,6 +1831,15 @@ AutoModelForSemanticSegmentation = auto_class_update( ) +class AutoModelForTimeSeriesPrediction(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING + + +AutoModelForTimeSeriesPrediction = auto_class_update( + AutoModelForTimeSeriesPrediction, head_doc="time-series prediction" +) + + class AutoModelForUniversalSegmentation(_BaseAutoModelClass): _model_mapping = MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING @@ -1994,6 +2014,7 @@ __all__ = [ "MODEL_FOR_TEXT_ENCODING_MAPPING", "MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING", "MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING", + "MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING", "MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING", "MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING", "MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING", diff --git a/src/transformers/models/timesfm/__init__.py b/src/transformers/models/timesfm/__init__.py new file mode 100644 index 00000000000..12f1541b9c5 --- /dev/null +++ b/src/transformers/models/timesfm/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/src/transformers/models/timesfm/__init__.py b/src/transformers/models/timesfm/__init__.py
new file mode 100644
index 00000000000..12f1541b9c5
--- /dev/null
+++ b/src/transformers/models/timesfm/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_timesfm import *
+    from .modeling_timesfm import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/src/transformers/models/timesfm/configuration_timesfm.py b/src/transformers/models/timesfm/configuration_timesfm.py
new file mode 100644
index 00000000000..bd371cf1b21
--- /dev/null
+++ b/src/transformers/models/timesfm/configuration_timesfm.py
@@ -0,0 +1,129 @@
+# coding=utf-8
+# Copyright 2025 Google LLC and HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""TimesFM model configuration"""
+
+from typing import List
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class TimesFmConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`TimesFmModelForPrediction`]. It is used to
+    instantiate a TimesFM model according to the specified arguments, defining the model architecture. Instantiating a
+    configuration with the defaults will yield a similar configuration to that of the TimesFM
+    [google/timesfm-2.0-500m-pytorch](https://huggingface.co/google/timesfm-2.0-500m-pytorch) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Arguments:
+        patch_length (`int`, *optional*, defaults to 32):
+            The length of one patch in the input sequence.
+        context_length (`int`, *optional*, defaults to 512):
+            The length of the input context.
+        horizon_length (`int`, *optional*, defaults to 128):
+            The length of the prediction horizon.
+        freq_size (`int`, *optional*, defaults to 3):
+            The number of frequency embeddings.
+        num_hidden_layers (`int`, *optional*, defaults to 50):
+            Number of Transformer layers.
+        hidden_size (`int`, *optional*, defaults to 1280):
+            Size of the hidden layers in the feed-forward networks.
+        intermediate_size (`int`, *optional*, defaults to 1280):
+            Dimension of the MLP representations.
+        head_dim (`int`, *optional*, defaults to 80):
+            Size of the key, query, value projections per attention head. The `inner_dim` of the projection layer will
+            be defined as `num_attention_heads * head_dim`.
+        num_attention_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        tolerance (`float`, *optional*, defaults to 1e-06):
+            The tolerance for the quantile loss.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+            The epsilon used by the RMS normalization layers.
+        quantiles (`List[float]`, *optional*, defaults to `[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]`):
+            The quantiles to predict.
+        pad_val (`float`, *optional*, defaults to 1123581321.0):
+            The value used to pad the predictions.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability for the attention scores.
+        use_positional_embedding (`bool`, *optional*, defaults to `False`):
+            Whether to add positional embeddings.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        min_timescale (`int`, *optional*, defaults to 1):
+            The start of the geometric positional index. Determines the periodicity of
+            the added signal.
+        max_timescale (`int`, *optional*, defaults to 10000):
+            The end of the geometric positional index. Determines the frequency of the
+            added signal.
+    """
+
+    model_type = "timesfm"
+    keys_to_ignore_at_inference = []
+    is_encoder_decoder = False
+
+    def __init__(
+        self,
+        patch_length: int = 32,
+        context_length: int = 512,
+        horizon_length: int = 128,
+        freq_size: int = 3,
+        num_hidden_layers: int = 50,
+        hidden_size: int = 1280,
+        intermediate_size: int = 1280,
+        head_dim: int = 80,
+        num_attention_heads: int = 16,
+        tolerance: float = 1e-6,
+        rms_norm_eps: float = 1e-6,
+        quantiles: List[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
+        pad_val: float = 1123581321.0,
+        attention_dropout: float = 0.0,
+        use_positional_embedding: bool = False,
+        initializer_range: float = 0.02,
+        min_timescale: int = 1,
+        max_timescale: int = 10_000,
+        **kwargs,
+    ):
+        self.patch_length = patch_length
+        self.context_length = context_length
+        self.horizon_length = horizon_length
+        self.quantiles = quantiles
+        self.pad_val = pad_val
+        self.freq_size = freq_size
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.head_dim = head_dim
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.tolerance = tolerance
+        self.rms_norm_eps = rms_norm_eps
+        self.attention_dropout = attention_dropout
+        self.use_positional_embedding = use_positional_embedding
+        self.initializer_range = initializer_range
+        self.min_timescale = min_timescale
+        self.max_timescale = max_timescale
+
+        super().__init__(
+            is_encoder_decoder=self.is_encoder_decoder,
+            **kwargs,
+        )
+
+
+__all__ = ["TimesFmConfig"]
diff --git a/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py b/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py
new file mode 100644
index 00000000000..06674fb087c
--- /dev/null
+++ b/src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py
@@ -0,0 +1,275 @@
+import argparse
+import os
+import re
+import shutil
+
+import numpy as np
+import timesfm
+import torch
+
+from transformers import TimesFmConfig, TimesFmModelForPrediction
+
+
+"""
+Sample usage:
+
+```
+python src/transformers/models/timesfm/convert_timesfm_orignal_to_hf.py \
+    --output_dir /output/path
+```
+"""
+
+
+def get_nested_attr(obj, key):
+    """Recursively retrieves an attribute from an object, handling list/tuple indexing if present."""
+    parts = key.split(".")
+    for part in parts:
+        match = re.match(r"(.*)\[(\d+)\]", part)  # Handle list indexing like `layers[0]`
+        if match:
+            attr_name, index = match.groups()
+            obj = getattr(obj, attr_name)[int(index)]  # Access list/tuple element
+        else:
+            obj = getattr(obj, part)  # Regular attribute access
+    return obj
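+
+
+# Illustrative examples (not part of the original script) of the two key forms
+# `get_nested_attr` resolves, using names from the mappings defined below:
+#   get_nested_attr(model, "freq_emb.weight")                                 # plain dotted access
+#   get_nested_attr(model, "stacked_transformer.layers[0].self_attn.o_proj")  # dotted access with list indexing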
+
+
+def write_model(model_path, safe_serialization=True, huggingface_repo_id="google/timesfm-2.0-500m-pytorch"):
+    os.makedirs(model_path, exist_ok=True)
+    tmp_model_path = os.path.join(model_path, "tmp")
+    os.makedirs(tmp_model_path, exist_ok=True)
+
+    tfm = timesfm.TimesFm(
+        hparams=timesfm.TimesFmHparams(
+            backend="cuda" if torch.cuda.is_available() else "cpu",
+            per_core_batch_size=32,
+            horizon_len=128,
+            input_patch_len=32,
+            output_patch_len=128,
+            num_layers=50,
+            model_dims=1280,
+            use_positional_embedding=False,
+        ),
+        checkpoint=timesfm.TimesFmCheckpoint(huggingface_repo_id=huggingface_repo_id),
+    )
+
+    timesfm_config = TimesFmConfig(
+        patch_length=tfm.hparams.input_patch_len,
+        context_length=tfm.hparams.context_len,
+        horizon_length=tfm.hparams.horizon_len,
+        num_hidden_layers=tfm.hparams.num_layers,
+        hidden_size=tfm.hparams.model_dims,
+        intermediate_size=tfm.hparams.model_dims,
+        head_dim=tfm.hparams.model_dims // tfm.hparams.num_heads,
+        num_attention_heads=tfm.hparams.num_heads,
+        use_positional_embedding=tfm.hparams.use_positional_embedding,
+    )
+    timesfm_config.save_pretrained(tmp_model_path)
+    timesfm_model = TimesFmModelForPrediction(timesfm_config)
+
+    # copy the weights from the original model to the new model
+    original_model = tfm._model
+
+    # mapping of the layers from the original model to the transformer model
+    MODEL_LAYER_MAPPING = {
+        "input_ff_layer.hidden_layer[0].weight": "decoder.input_ff_layer.input_layer.weight",
+        "input_ff_layer.hidden_layer[0].bias": "decoder.input_ff_layer.input_layer.bias",
+        "input_ff_layer.output_layer.weight": "decoder.input_ff_layer.output_layer.weight",
+        "input_ff_layer.output_layer.bias": "decoder.input_ff_layer.output_layer.bias",
+        "input_ff_layer.residual_layer.weight": "decoder.input_ff_layer.residual_layer.weight",
+        "input_ff_layer.residual_layer.bias": "decoder.input_ff_layer.residual_layer.bias",
+        "freq_emb.weight": "decoder.freq_emb.weight",
+        "horizon_ff_layer.hidden_layer[0].weight": "horizon_ff_layer.input_layer.weight",
+        "horizon_ff_layer.hidden_layer[0].bias": "horizon_ff_layer.input_layer.bias",
+        "horizon_ff_layer.output_layer.weight": "horizon_ff_layer.output_layer.weight",
+        "horizon_ff_layer.output_layer.bias": "horizon_ff_layer.output_layer.bias",
+        "horizon_ff_layer.residual_layer.weight": "horizon_ff_layer.residual_layer.weight",
+        "horizon_ff_layer.residual_layer.bias": "horizon_ff_layer.residual_layer.bias",
+    }
+
+    TRANSFORMER_LAYER_MAPPING = {
+        "stacked_transformer.layers[{i}].self_attn.qkv_proj.weight": "decoder.layers[{i}].self_attn.qkv_proj.weight",
+        "stacked_transformer.layers[{i}].self_attn.qkv_proj.bias": "decoder.layers[{i}].self_attn.qkv_proj.bias",
+        "stacked_transformer.layers[{i}].self_attn.o_proj.weight": "decoder.layers[{i}].self_attn.o_proj.weight",
+        "stacked_transformer.layers[{i}].self_attn.o_proj.bias": "decoder.layers[{i}].self_attn.o_proj.bias",
+        "stacked_transformer.layers[{i}].self_attn.scaling": "decoder.layers[{i}].self_attn.scaling",
+        "stacked_transformer.layers[{i}].mlp.gate_proj.weight": "decoder.layers[{i}].mlp.gate_proj.weight",
+        "stacked_transformer.layers[{i}].mlp.gate_proj.bias": "decoder.layers[{i}].mlp.gate_proj.bias",
+        "stacked_transformer.layers[{i}].mlp.down_proj.weight": "decoder.layers[{i}].mlp.down_proj.weight",
+        "stacked_transformer.layers[{i}].mlp.down_proj.bias": "decoder.layers[{i}].mlp.down_proj.bias",
+        "stacked_transformer.layers[{i}].mlp.layer_norm.weight": "decoder.layers[{i}].mlp.layer_norm.weight",
+        "stacked_transformer.layers[{i}].mlp.layer_norm.bias": "decoder.layers[{i}].mlp.layer_norm.bias",
+        "stacked_transformer.layers[{i}].input_layernorm.weight":
"decoder.layers[{i}].input_layernorm.weight", + } + + for old_key, new_key in MODEL_LAYER_MAPPING.items(): + try: + old_attr = get_nested_attr(original_model, old_key) # Get tensor from original model + new_attr = get_nested_attr(timesfm_model, new_key) # Get corresponding attribute in new model + new_attr.data.copy_(old_attr.data) # Copy data + except AttributeError: + print(f"Skipping {old_key} (not found in original model).") + + num_layers = len(timesfm_model.decoder.layers) + for i in range(num_layers): + for old_template, new_template in TRANSFORMER_LAYER_MAPPING.items(): + old_key = old_template.format(i=i) + new_key = new_template.format(i=i) + + try: + # Get tensor from original model + old_attr = get_nested_attr(original_model, old_key) + if "qkv_proj" in old_key: + # Split the tensor into q, k, v projections + q_proj, k_proj, v_proj = ( + old_attr[: tfm.hparams.model_dims, ...], + old_attr[tfm.hparams.model_dims : tfm.hparams.model_dims * 2, ...], + old_attr[tfm.hparams.model_dims * 2 :, ...], + ) + # Get corresponding attribute in new model + q_key = new_key.replace("qkv_proj", "q_proj") + q_attr = get_nested_attr(timesfm_model, q_key) + q_attr.data.copy_(q_proj.data) # Copy data + k_key = new_key.replace("qkv_proj", "k_proj") + k_attr = get_nested_attr(timesfm_model, k_key) + k_attr.data.copy_(k_proj.data) # Copy data + v_key = new_key.replace("qkv_proj", "v_proj") + v_attr = get_nested_attr(timesfm_model, v_key) + v_attr.data.copy_(v_proj.data) # Copy data + else: + # Get corresponding attribute in new model + new_attr = get_nested_attr(timesfm_model, new_key) + new_attr.data.copy_(old_attr.data) # Copy data + except AttributeError: + print(f"Skipping {old_key} (not found in original model).") + + timesfm_model.save_pretrained(model_path, safe_serialization=safe_serialization) + shutil.rmtree(tmp_model_path) + + +def check_outputs(model_path, huggingface_repo_id): + """Compares outputs between original and converted models.""" + print("\nChecking model outputs...") + + # Load original model + tfm = timesfm.TimesFm( + hparams=timesfm.TimesFmHparams( + backend="cuda" if torch.cuda.is_available() else "cpu", + per_core_batch_size=32, + horizon_len=128, + input_patch_len=32, + output_patch_len=128, + num_layers=50, + model_dims=1280, + use_positional_embedding=False, + point_forecast_mode="mean", + ), + checkpoint=timesfm.TimesFmCheckpoint(huggingface_repo_id=huggingface_repo_id), + ) + + # Load converted model + converted_model = TimesFmModelForPrediction.from_pretrained( + model_path, + torch_dtype=torch.bfloat16, + attn_implementation="sdpa", + ).to("cuda" if torch.cuda.is_available() else "cpu") + converted_model.eval() # Set to evaluation mode + + # Create test inputs + forecast_input = [ + np.sin(np.linspace(0, 20, 100)), + np.sin(np.linspace(0, 20, 200)), + np.sin(np.linspace(0, 20, 400)), + ] + frequency_input = [0, 1, 2] + + # Get predictions from original model + point_forecast_orig, quantile_forecast_orig = tfm.forecast( + forecast_input, + freq=frequency_input, + ) + + # Convert inputs to sequence of tensors + forecast_input_tensor = [ + torch.tensor(ts, dtype=torch.bfloat16).to("cuda" if torch.cuda.is_available() else "cpu") + for ts in forecast_input + ] + frequency_input_tensor = torch.tensor(frequency_input, dtype=torch.long).to( + "cuda" if torch.cuda.is_available() else "cpu" + ) + + # Get predictions from converted model + with torch.no_grad(): + outputs = converted_model(past_values=forecast_input_tensor, freq=frequency_input_tensor, return_dict=True) + 
point_forecast_conv = outputs.mean_predictions.float().cpu().numpy() + quantile_forecast_conv = outputs.full_predictions.float().cpu().numpy() + + # Compare outputs + point_forecast_diff = np.abs(point_forecast_orig - point_forecast_conv) + quantile_forecast_diff = np.abs(quantile_forecast_orig - quantile_forecast_conv) + + max_point_diff = point_forecast_diff.max() + mean_point_diff = point_forecast_diff.mean() + max_quantile_diff = quantile_forecast_diff.max() + mean_quantile_diff = quantile_forecast_diff.mean() + + print("\nOutput comparison:") + print(f"Point forecast - Max difference: {max_point_diff:.6f}") + print(f"Point forecast - Mean difference: {mean_point_diff:.6f}") + print(f"Quantile forecast - Max difference: {max_quantile_diff:.6f}") + print(f"Quantile forecast - Mean difference: {mean_quantile_diff:.6f}") + + # Define acceptable thresholds + POINT_THRESHOLD = 1e-5 + QUANTILE_THRESHOLD = 1e-5 + + if max_point_diff > POINT_THRESHOLD or max_quantile_diff > QUANTILE_THRESHOLD: + raise ValueError( + f"Output mismatch detected!\n" + f"Point forecast max diff: {max_point_diff} (threshold: {POINT_THRESHOLD})\n" + f"Quantile forecast max diff: {max_quantile_diff} (threshold: {QUANTILE_THRESHOLD})" + ) + + print("\n✓ All outputs match within acceptable tolerance!") + + # Optional: Print shapes for verification + print("\nOutput shapes:") + print(f"Original point forecast: {point_forecast_orig.shape}") + print(f"Converted point forecast: {point_forecast_conv.shape}") + print(f"Original quantile forecast: {quantile_forecast_orig.shape}") + print(f"Converted quantile forecast: {quantile_forecast_conv.shape}") + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--output_dir", + required=True, + help="Location to write HF model and tokenizer", + ) + parser.add_argument( + "--safe_serialization", type=bool, default=True, help="Whether or not to save using `safetensors`." + ) + parser.add_argument( + "--huggingface_repo_id", + type=str, + default="google/timesfm-2.0-500m-pytorch", + help="The Hugging Face repository ID to use for the model.", + ) + args = parser.parse_args() + + # if the saved model file exists, skip the conversion + if os.path.exists( + os.path.join(args.output_dir, "model.safetensors" if args.safe_serialization else "pytorch_model.bin") + ): + print(f"Model already exists in {args.output_dir}, skipping conversion.") + else: + write_model( + model_path=args.output_dir, + safe_serialization=args.safe_serialization, + huggingface_repo_id=args.huggingface_repo_id, + ) + check_outputs(args.output_dir, args.huggingface_repo_id) + + +if __name__ == "__main__": + main() diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py new file mode 100644 index 00000000000..8a12b2c56cf --- /dev/null +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -0,0 +1,904 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/timesfm/modular_timesfm.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_timesfm.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 Google LLC and HuggingFace Inc. team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from typing import Callable, Optional, Sequence, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...integrations import use_kernel_forward_from_hub +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import BaseModelOutput +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + can_return_tuple, + logging, + replace_return_docstrings, +) +from .configuration_timesfm import TimesFmConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "google/timesfm-2.0-500m-pytorch" +_CONFIG_FOR_DOC = "TimesFmConfig" + + +@dataclass +class TimesFmOutput(BaseModelOutput): + """ + Args: + loc (`torch.Tensor` of shape `(batch_size, )`): + The mean of the time series inputs. + scale (`torch.Tensor` of shape `(batch_size,)`): + The scale of the time series inputs. + """ + + loc: Optional[torch.Tensor] = None + scale: Optional[torch.Tensor] = None + + +@dataclass +class TimesFmOutputForPrediction(BaseModelOutput): + """ + Args: + mean_predictions (`torch.Tensor` of shape `(batch_size, sequence_length)`): + The mean predictions of the time series. + full_predictions (`torch.Tensor` of shape `(batch_size, sequence_length)`): + The full predictions of the time series including the mean and the quantiles. + loss (`torch.Tensor` of shape `(1,)`, *optional*, returned when `future_values` is provided): + The loss of the TimesFM model. 
+    """
+
+    mean_predictions: Optional[torch.Tensor] = None
+    full_predictions: Optional[torch.Tensor] = None
+    loss: Optional[Union[torch.Tensor, float]] = None
+
+
+class TimesFmMLP(nn.Module):
+    """Pax MLP in pytorch."""
+
+    def __init__(self, config: TimesFmConfig):
+        super().__init__()
+        hidden_size = config.hidden_size
+        intermediate_size = config.intermediate_size
+
+        self.gate_proj = nn.Linear(hidden_size, intermediate_size)
+        self.down_proj = nn.Linear(intermediate_size, hidden_size)
+        self.layer_norm = nn.LayerNorm(normalized_shape=hidden_size, eps=1e-6)
+
+    def forward(self, x, paddings=None):
+        gate_inp = self.layer_norm(x)
+        gate = self.gate_proj(gate_inp)
+        gate = F.relu(gate)
+        outputs = self.down_proj(gate)
+        if paddings is not None:
+            outputs = outputs * (1.0 - paddings[:, :, None])
+        return outputs + x
+
+
+class TimesFmResidualBlock(nn.Module):
+    """TimesFM residual block."""
+
+    def __init__(self, input_dims, hidden_dims, output_dims):
+        super().__init__()
+        self.input_dims = input_dims
+        self.hidden_dims = hidden_dims
+        self.output_dims = output_dims
+
+        self.input_layer = nn.Linear(input_dims, hidden_dims)
+        self.activation = nn.SiLU()
+        self.output_layer = nn.Linear(hidden_dims, output_dims)
+        self.residual_layer = nn.Linear(input_dims, output_dims)
+
+    def forward(self, x):
+        hidden = self.input_layer(x)
+        hidden = self.activation(hidden)
+        output = self.output_layer(hidden)
+        residual = self.residual_layer(x)
+        return output + residual
+
+
+@use_kernel_forward_from_hub("RMSNorm")
+class TimesFmRMSNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        """
+        TimesFmRMSNorm is equivalent to T5LayerNorm
+        """
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight * hidden_states.to(input_dtype)
+
+    def extra_repr(self):
+        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+class TimesFmPositionalEmbedding(nn.Module):
+    """Generates position embedding for a given 1-d sequence."""
+
+    def __init__(self, config: TimesFmConfig):
+        super().__init__()
+        min_timescale = config.min_timescale
+        max_timescale = config.max_timescale
+        self.embedding_dims = config.hidden_size
+
+        num_timescales = self.embedding_dims // 2
+        log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / max(num_timescales - 1, 1)
+        self.register_buffer(
+            "inv_timescales",
+            min_timescale * torch.exp(torch.arange(num_timescales, dtype=torch.float32) * -log_timescale_increment),
+        )
+
+    def forward(self, seq_length=None, position=None):
+        """Generates a Tensor of sinusoids with different frequencies.
+
+        Args:
+            seq_length: an optional Python int defining the output sequence length;
+                not needed if the `position` argument is specified.
+            position: [B, seq_length], optional position for each token in the
+                sequence, only required when the sequence is packed.
+
+        Returns:
+            [B, seqlen, D] if `position` is specified, else [1, seqlen, D]
+        """
+        if position is None and seq_length is None:
+            raise ValueError("Either position or seq_length must be provided")
+
+        if position is None:
+            # [1, seqlen]
+            position = torch.arange(seq_length, dtype=torch.float32, device=self.inv_timescales.device).unsqueeze(0)
+        elif position.ndim != 2:
+            raise ValueError(f"position must be 2-dimensional, got shape {position.shape}")
+
+        scaled_time = position.view(*position.shape, 1) * self.inv_timescales.view(1, 1, -1)
+        signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)
+
+        # Padding to ensure correct embedding dimension
+        signal = F.pad(signal, (0, 0, 0, self.embedding_dims % 2))
+        return signal
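+
+
+# Example (illustrative): with the default hidden_size of 1280,
+#   TimesFmPositionalEmbedding(config)(seq_length=16)
+# returns a [1, 16, 1280] tensor whose first 640 channels are sines and last 640
+# are cosines of position * inv_timescale, as in the original Transformer.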
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=1.0, + **kwargs, + ) + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class TimesFmDecoderLayer(nn.Module): + """Transformer layer.""" + + def __init__(self, config: TimesFmConfig, layer_idx: int): + super().__init__() + + self.self_attn = TimesFmAttention(config, layer_idx=layer_idx) + self.mlp = TimesFmMLP(config) + self.input_layernorm = TimesFmRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + paddings: torch.Tensor, + output_attentions: bool = False, + ) -> tuple[Optional[torch.Tensor], torch.Tensor]: + # Self Attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states, scores = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + # MLP + hidden_states = self.mlp(hidden_states, paddings=paddings) + + return scores, hidden_states + + +TIMESFM_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`TimesFmConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
+""" + + +@add_start_docstrings( + "The bare TimesFM Model outputting raw hidden-states without any specific head on top.", + TIMESFM_START_DOCSTRING, +) +class TimesFmPreTrainedModel(PreTrainedModel): + """handles the loading for all models.""" + + config_class = TimesFmConfig + base_model_prefix = "timesfm" + _no_split_modules = ["TimesFmDecoderLayer"] + main_input_name = "past_values" + _supports_sdpa = True + + def _init_weights(self, module): + if isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0, std=self.config.initializer_range) + + elif isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0, std=self.config.initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + + elif isinstance(module, nn.LayerNorm): + nn.init.ones_(module.weight) + nn.init.zeros_(module.bias) + + elif isinstance(module, TimesFmRMSNorm): + nn.init.zeros_(module.weight) + + elif isinstance(module, TimesFmAttention): + # Initialize scaling parameter + nn.init.ones_(module.scaling) + + +TIMESFM_INPUTS_DOCSTRING = r""" + Args: + past_values: list of time series forecast contexts. Each context time series + can be a torch Tensor of potentially different context lengths. + freq: frequency of each context time series in the inputs. 0 for high frequency + (default), 1 for medium, and 2 for low. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. +""" + + +@add_start_docstrings( + "The bare TimesFM Model outputting raw hidden-states without any specific head on top.", + TIMESFM_START_DOCSTRING, +) +class TimesFmModel(TimesFmPreTrainedModel): + """Patched time-series decoder without any specific output layer.""" + + def __init__(self, config: TimesFmConfig): + super().__init__(config) + + self.config = config + self.input_ff_layer = TimesFmResidualBlock( + input_dims=2 * config.patch_length, + output_dims=config.hidden_size, + hidden_dims=config.intermediate_size, + ) + self.freq_emb = nn.Embedding(num_embeddings=config.freq_size, embedding_dim=config.hidden_size) + self.layers = nn.ModuleList( + [TimesFmDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + if self.config.use_positional_embedding: + self.position_emb = TimesFmPositionalEmbedding(config=config) + + # Initialize weights and apply final processing + self.post_init() + + def _forward_transform( + self, inputs: torch.Tensor, patched_pads: torch.Tensor + ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + """Input is of shape [B, N, P].""" + mu, sigma = self._timesfm_masked_mean_std(inputs, patched_pads) + sigma = torch.where( + sigma < self.config.tolerance, + torch.tensor(1.0, dtype=sigma.dtype, device=sigma.device), + sigma, + ) + + # Normalize each patch + outputs = (inputs - mu[:, None, None]) / sigma[:, None, None] + outputs = torch.where( + torch.abs(inputs - self.config.pad_val) < self.config.tolerance, + torch.tensor(self.config.pad_val, dtype=outputs.dtype, device=outputs.device), + outputs, + ) + return outputs, (mu, sigma) + + @can_return_tuple + @add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING) + def forward( + self, + past_values: torch.Tensor, + past_values_padding: torch.LongTensor, + freq: torch.Tensor, + 
output_attentions: bool = False, + output_hidden_states: bool = False, + ) -> TimesFmOutput: + """ + past_values_padding (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The padding indicator of the time series. + """ + # Reshape into patches (using view for efficiency) + bsize = past_values.shape[0] + patched_inputs = past_values.view(bsize, -1, self.config.patch_length) + patched_pads = past_values_padding.view(bsize, -1, self.config.patch_length) + + patched_inputs = torch.where( + torch.abs(patched_pads - 1.0) < self.config.tolerance, + torch.tensor(0.0, dtype=patched_inputs.dtype, device=patched_inputs.device), + patched_inputs, + ) + patched_pads = torch.where( + torch.abs(patched_inputs - self.config.pad_val) < self.config.tolerance, + torch.tensor(1.0, dtype=patched_pads.dtype, device=patched_pads.device), + patched_pads, + ) + patched_inputs, stats = self._forward_transform(patched_inputs, patched_pads) + + # B x N x D + patched_inputs = patched_inputs * (1.0 - patched_pads) + concat_inputs = torch.cat([patched_inputs, patched_pads], dim=-1) + model_input = self.input_ff_layer(concat_inputs) + + # A patch should not be padded even if there is at least one zero. + patched_padding = torch.min(patched_pads, dim=-1)[0] # Get the values from the min result + if self.config.use_positional_embedding: + pos_emb = self.position_emb(model_input.shape[1]) + pos_emb = torch.concat([pos_emb] * model_input.shape[0], dim=0) + pos_emb = self._timesfm_shift_padded_seq(patched_padding, pos_emb) + model_input += pos_emb + + f_emb = self.freq_emb(freq) # B x 1 x D + model_input += f_emb + + # Convert paddings to attention mask and combine with causal mask + hidden_states = model_input + attention_mask = self._prepare_4d_attention_mask( + attention_mask=patched_padding, + sequence_length=hidden_states.shape[1], + dtype=hidden_states.dtype, + device=hidden_states.device, + is_causal=True, + ) + + all_attentions = [] + all_hidden_states = [] + + for layer in self.layers[: self.config.num_hidden_layers]: + scores, hidden_states = layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + paddings=patched_padding, + output_attentions=output_attentions, + ) + if output_attentions: + all_attentions.append(scores) + if output_hidden_states: + all_hidden_states.append(hidden_states) + + if output_hidden_states: + all_hidden_states = [model_input] + all_hidden_states + else: + all_hidden_states = None + + return TimesFmOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_attentions if output_attentions else None, + loc=stats[0], + scale=stats[1], + ) + + @staticmethod + def _prepare_4d_attention_mask( + attention_mask: Optional[torch.Tensor], + sequence_length: int, + dtype: torch.dtype, + device: torch.device, + is_causal: bool = True, + ) -> Optional[torch.Tensor]: + """ + Creates 4D attention mask and combines causal and padding masks if needed. 
+
+        Args:
+            attention_mask: Optional tensor of shape (batch_size, seq_length) containing padding mask
+            sequence_length: Length of the sequence
+            dtype: Data type of the mask
+            device: Device of the mask
+            is_causal: Whether to apply causal masking
+
+        Returns:
+            4D attention mask of shape (batch_size, 1, seq_length, seq_length)
+        """
+        # Get minimum value for the dtype
+        min_value = torch.finfo(dtype).min if dtype.is_floating_point else torch.iinfo(dtype).min
+
+        # Handle padding mask
+        if attention_mask is not None:
+            # Convert 2D padding mask to 4D attention mask
+            attention_mask = attention_mask.view(attention_mask.shape[0], 1, 1, -1)
+            attention_mask = attention_mask * min_value
+
+        # Create causal mask if needed
+        if is_causal:
+            causal_mask = torch.triu(
+                torch.ones((sequence_length, sequence_length), dtype=dtype, device=device) * min_value,
+                diagonal=1,
+            )
+            causal_mask = causal_mask.view(1, 1, sequence_length, sequence_length)
+
+            # Combine with padding mask if it exists
+            if attention_mask is not None:
+                attention_mask = torch.minimum(attention_mask, causal_mask)
+            else:
+                attention_mask = causal_mask
+
+        return attention_mask
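+
+    # Example (illustrative): for a padding mask [[0., 0., 1.]] (third patch padded)
+    # and sequence_length=3, the returned mask has shape (1, 1, 3, 3); the causal
+    # part fills the strict upper triangle with the dtype minimum, while the padding
+    # part masks the last key column for every query row, so no position can attend
+    # to the padded patch.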
+
+    @staticmethod
+    def _timesfm_masked_mean_std(inputs: torch.Tensor, padding: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        """Calculates mean and standard deviation of `inputs` across axis 1.
+
+        It excludes values where `padding` is 1.
+
+        Args:
+            inputs: A PyTorch tensor of shape [b, n, p].
+            padding: A PyTorch tensor of shape [b, n, p] with values 0 or 1.
+
+        Returns:
+            A tuple containing the mean and standard deviation.
+            We return the statistics of the first patch with more than three non-padded values.
+        """
+
+        # Selecting the first patch with more than 3 unpadded values.
+        def _get_patch_index(arr: torch.Tensor):
+            indices = torch.argmax((arr >= 3).to(torch.int32), dim=1)
+            row_sum = (arr >= 3).to(torch.int32).sum(dim=1)
+            return torch.where(row_sum == 0, arr.shape[1] - 1, indices)
+
+        pad_sum = torch.sum(1 - padding, dim=2)
+        patch_indices = _get_patch_index(pad_sum)
+        bidxs = torch.arange(inputs.shape[0])
+
+        arr = inputs[bidxs, patch_indices, :]
+        pad = padding[bidxs, patch_indices, :]
+
+        # Create a mask where padding is 0
+        mask = 1 - pad
+
+        # Calculate the number of valid elements
+        num_valid_elements = torch.sum(mask, dim=1)
+        num_valid_elements = torch.where(
+            num_valid_elements == 0,
+            torch.tensor(1, dtype=num_valid_elements.dtype, device=num_valid_elements.device),
+            num_valid_elements,
+        )
+
+        # Calculate the masked sum and squared sum
+        masked_sum = torch.sum(arr * mask, dim=1)
+        masked_squared_sum = torch.sum((arr * mask) ** 2, dim=1)
+
+        # Calculate the masked mean and standard deviation
+        masked_mean = masked_sum / num_valid_elements
+        masked_var = masked_squared_sum / num_valid_elements - masked_mean**2
+        masked_var = torch.where(
+            masked_var < 0.0,
+            torch.tensor(0.0, dtype=masked_var.dtype, device=masked_var.device),
+            masked_var,
+        )
+        masked_std = torch.sqrt(masked_var)
+
+        return masked_mean, masked_std
+
+    @staticmethod
+    def _timesfm_shift_padded_seq(mask: torch.Tensor, seq: torch.Tensor) -> torch.Tensor:
+        """Shifts rows of seq based on the first 0 in each row of the mask.
+
+        Args:
+            mask: mask tensor of shape [B, N]
+            seq: seq tensor of shape [B, N, P]
+
+        Returns:
+            The shifted sequence.
+        """
+        batch_size, num_seq, feature_dim = seq.shape
+
+        new_mask: torch.BoolTensor = mask == 0
+
+        # Use argmax to find the first True value in each row
+        indices = new_mask.to(torch.int32).argmax(dim=1)
+
+        # Handle rows with all zeros
+        indices[~new_mask.any(dim=1)] = -1
+
+        # Create index ranges for each sequence in the batch
+        idx_range = torch.arange(num_seq, device=seq.device).view(1, -1, 1).expand(batch_size, -1, feature_dim)
+
+        # Calculate shifted indices for each element in each sequence
+        shifted_idx = (idx_range - indices[:, None, None]) % num_seq
+
+        # Gather values from seq using shifted indices
+        shifted_seq = seq.gather(1, shifted_idx)
+
+        return shifted_seq
+
+
+class TimesFmModelForPrediction(TimesFmPreTrainedModel):
+    """TimesFM model for quantile and mean prediction."""
+
+    def __init__(self, config: TimesFmConfig):
+        super().__init__(config)
+
+        self.config = config
+        self.context_len = config.context_length
+        self.horizon_len = config.horizon_length
+
+        self.decoder = TimesFmModel(config)
+
+        # quantile and mean output
+        self.horizon_ff_layer = TimesFmResidualBlock(
+            input_dims=config.hidden_size,
+            output_dims=config.horizon_length * (1 + len(config.quantiles)),
+            hidden_dims=config.intermediate_size,
+        )
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def _preprocess(
+        self, inputs: Sequence[torch.Tensor], freq: Sequence[int]
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Formats and pads raw inputs to feed into the model.
+
+        This function left-pads each time series that is shorter than the context
+        length, truncates each one that is longer, and builds the padding indicator
+        covering both the context and the forecast horizon.
+
+        Args:
+            inputs: A list of 1d Tensors. Each Tensor is the context time series of
+                a single forecast task.
+            freq: list of frequencies
+
+        Returns:
+            A tuple of:
+            - the padded input time series to meet the model required context.
+            - the padding indicator.
+            - the frequency of each example, as an `int32` tensor of shape `(batch_size, 1)`.
+        """
+        input_ts, input_padding, inp_freq = [], [], []
+
+        for i, ts in enumerate(inputs):
+            input_len = ts.shape[0]
+            padding = torch.zeros(input_len + self.horizon_len, dtype=ts.dtype, device=ts.device)
+            if input_len < self.context_len:
+                num_front_pad = self.context_len - input_len
+                ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0)
+                padding = torch.cat([torch.ones(num_front_pad, dtype=ts.dtype, device=padding.device), padding], dim=0)
+            elif input_len > self.context_len:
+                ts = ts[-self.context_len :]
+                padding = padding[-(self.context_len + self.horizon_len) :]
+
+            input_ts.append(ts)
+            input_padding.append(padding)
+            inp_freq.append(freq[i])
+
+        return (
+            torch.stack(input_ts, dim=0),
+            torch.stack(input_padding, dim=0),
+            torch.tensor(inp_freq, dtype=torch.int32).reshape(-1, 1),
+        )
+
+    def _postprocess_output(
+        self, model_output: torch.Tensor, stats: tuple[torch.Tensor, torch.Tensor]
+    ) -> torch.Tensor:
+        """Postprocess output of stacked transformer."""
+
+        # B x N x (H.Q)
+        output_ts = self.horizon_ff_layer(model_output)
+
+        # Reshape using view
+        b, n, _ = output_ts.shape
+        output_ts = output_ts.view(b, n, self.config.horizon_length, len(self.config.quantiles) + 1)
+
+        mu, sigma = stats
+        return output_ts * sigma[:, None, None, None] + mu[:, None, None, None]
+
+    def _quantile_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+        losses = []
+        for i, q in enumerate(self.config.quantiles):
+            errors = targets - predictions[..., i]
+            loss = torch.max((q - 1) * errors, q * errors)
+            losses.append(loss.mean())
+        return torch.stack(losses).mean()
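+
+    # Worked example (illustrative) of the pinball loss above for q = 0.9:
+    #   under-prediction: target 1.0, prediction 0.0 -> errors = 1.0,
+    #     loss = max(-0.1 * 1.0, 0.9 * 1.0) = 0.9
+    #   over-prediction:  target 0.0, prediction 1.0 -> errors = -1.0,
+    #     loss = max(-0.1 * -1.0, 0.9 * -1.0) = 0.1
+    # Under-prediction is penalized 9x more, pushing each head toward its quantile.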
+            - loss: the mean squared error loss + quantile loss if `future_values` is provided.
+        """
+        if forecast_context_len is None:
+            fcontext_len = self.context_len
+        else:
+            fcontext_len = forecast_context_len
+
+        # Get device from first input tensor
+        device = past_values[0].device
+
+        # Truncate inputs to forecast_context_len
+        inputs = [ts[-fcontext_len:] for ts in past_values]
+        inp_min = torch.min(torch.stack([torch.min(ts) for ts in inputs]))
+
+        if window_size is not None:
+            new_inputs = []
+            new_freqs = []
+            for i, ts in enumerate(inputs):
+                new_inputs.extend(self._timesfm_moving_average(ts, window_size))
+                if freq is not None:
+                    new_freqs.extend([freq[i]] * 2)
+            inputs = new_inputs
+            if freq is not None:
+                freq = new_freqs
+
+        if freq is None:
+            logger.info("No frequency provided via `freq`. Defaulting to high frequency (0).")
+            freq = [0] * len(inputs)
+
+        if output_attentions is None:
+            output_attentions = self.config.output_attentions
+        if output_hidden_states is None:
+            output_hidden_states = self.config.output_hidden_states
+
+        input_ts, input_padding, inp_freq = self._preprocess(inputs, freq)
+        # Move tensors to the same device as input
+        input_ts = input_ts.to(device)
+        input_padding = input_padding.to(device)
+        inp_freq = inp_freq.to(device)
+
+        final_out = input_ts
+        context_len = final_out.shape[1]
+        full_outputs = []
+
+        if input_padding.shape[1] != final_out.shape[1] + self.horizon_len:
+            raise ValueError(
+                "Length of paddings must match length of input + horizon_len:"
+                f" {input_padding.shape[1]} != {final_out.shape[1]} + {self.horizon_len}"
+            )
+        output_patch_len = self.config.horizon_length
+
+        num_decode_patches = (self.horizon_len + output_patch_len - 1) // output_patch_len
+        for step_index in range(num_decode_patches):
+            current_padding = input_padding[:, 0 : final_out.shape[1]]
+            input_ts = final_out[:, -fcontext_len:]
+            input_padding = current_padding[:, -fcontext_len:]
+            decoder_output = self.decoder(
+                past_values=input_ts,
+                past_values_padding=input_padding,
+                freq=inp_freq,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+            )
+            fprop_outputs = self._postprocess_output(
+                decoder_output.last_hidden_state,
+                (decoder_output.loc, decoder_output.scale),
+            )
+
+            if return_forecast_on_context and step_index == 0:
+                # For the first decoding step, collect the model forecast on the
+                # context, except the unavailable forecast for the first input patch.
+                new_full_ts = fprop_outputs[:, :-1, : self.config.patch_length, :]
+                # We have to use reshape and not view for non-contiguous memory
+                new_full_ts = new_full_ts.reshape(new_full_ts.size(0), -1, new_full_ts.size(3))
+
+                full_outputs.append(new_full_ts)
+
+            # (full batch, last patch, output_patch_len, index of mean forecast = 0)
+            new_ts = fprop_outputs[:, -1, :output_patch_len, 0]
+            new_full_ts = fprop_outputs[:, -1, :output_patch_len, :]
+            # (full batch, last patch, output_patch_len, all output indices)
+            full_outputs.append(new_full_ts)
+            final_out = torch.concatenate([final_out, new_ts], axis=-1)
+
+        if return_forecast_on_context:
+            # `full_outputs` indexing starts right after the first input patch.
+            full_outputs = torch.concatenate(full_outputs, axis=1)[
+                :, : (context_len - self.config.patch_length + self.horizon_len), :
+            ]
+        else:
+            # `full_outputs` indexing starts at the forecast horizon.
+            full_outputs = torch.concatenate(full_outputs, axis=1)[:, 0 : self.horizon_len, :]
+
+        mean_outputs = full_outputs[:, :, 0]
+        if window_size is not None:
+            mean_outputs = mean_outputs[0::2, ...] + mean_outputs[1::2, ...]
+            full_outputs = full_outputs[0::2, ...] + full_outputs[1::2, ...]
+        if inp_min >= 0 and truncate_negative:
+            mean_outputs = torch.maximum(mean_outputs, torch.zeros_like(mean_outputs))
+            full_outputs = torch.maximum(full_outputs, torch.zeros_like(full_outputs))
+
+        loss = None
+        if future_values is not None:
+            mse_loss = F.mse_loss(mean_outputs, future_values)
+            quantile_loss = self._quantile_loss(full_outputs[:, :, 1:], future_values)
+            loss = mse_loss + quantile_loss
+
+        return TimesFmOutputForPrediction(
+            last_hidden_state=decoder_output.last_hidden_state,
+            attentions=decoder_output.attentions if output_attentions else None,
+            hidden_states=decoder_output.hidden_states if output_hidden_states else None,
+            mean_predictions=mean_outputs,
+            full_predictions=full_outputs,
+            loss=loss,
+        )
+
+    @staticmethod
+    def _timesfm_moving_average(arr: torch.Tensor, window_size: int) -> list[torch.Tensor]:
+        """Calculates the moving average using PyTorch's convolution function."""
+        # Pad with zeros to handle initial window positions
+        arr_padded = F.pad(arr, (window_size - 1, 0), "constant", 0)
+        # Create a convolution kernel
+        kernel = torch.ones(window_size, dtype=arr.dtype, device=arr.device) / window_size
+        # Apply convolution to calculate the moving average
+        smoothed_arr = F.conv1d(arr_padded.view(1, 1, -1), kernel.view(1, 1, -1)).squeeze()
+        return [smoothed_arr, arr - smoothed_arr]
+
+
+__all__ = ["TimesFmModelForPrediction", "TimesFmPreTrainedModel", "TimesFmModel"]
diff --git a/src/transformers/models/timesfm/modular_timesfm.py b/src/transformers/models/timesfm/modular_timesfm.py
new file mode 100644
index 00000000000..4a627524849
--- /dev/null
+++ b/src/transformers/models/timesfm/modular_timesfm.py
@@ -0,0 +1,860 @@
+# coding=utf-8
+# Copyright 2025 Google LLC and HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch TimesFM model."""
+
+import math
+from dataclasses import dataclass
+from typing import Callable, Optional, Sequence, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_outputs import BaseModelOutput
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import (
+    add_code_sample_docstrings,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    can_return_tuple,
+    logging,
+    replace_return_docstrings,
+)
+from ..llama.modeling_llama import LlamaRMSNorm
+from ..phi4_multimodal.modeling_phi4_multimodal import simple_eager_attention_forward
+from .configuration_timesfm import TimesFmConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "google/timesfm-2.0-500m-pytorch"
+_CONFIG_FOR_DOC = "TimesFmConfig"
+
+
+@dataclass
+class TimesFmOutput(BaseModelOutput):
+    """
+    Args:
+        loc (`torch.Tensor` of shape `(batch_size,)`):
+            The mean of the time series inputs.
+        scale (`torch.Tensor` of shape `(batch_size,)`):
+            The scale of the time series inputs.
+    """
+
+    loc: Optional[torch.Tensor] = None
+    scale: Optional[torch.Tensor] = None
+
+
+@dataclass
+class TimesFmOutputForPrediction(BaseModelOutput):
+    """
+    Args:
+        mean_predictions (`torch.Tensor` of shape `(batch_size, sequence_length)`):
+            The mean predictions of the time series.
+        full_predictions (`torch.Tensor` of shape `(batch_size, sequence_length, 1 + number_of_quantiles)`):
+            The full predictions of the time series including the mean and the quantiles.
+        loss (`torch.Tensor` of shape `(1,)`, *optional*, returned when `future_values` is provided):
+            The loss of the TimesFM model.
+    """
+
+    mean_predictions: Optional[torch.Tensor] = None
+    full_predictions: Optional[torch.Tensor] = None
+    loss: Optional[Union[torch.Tensor, float]] = None
+
+
+class TimesFmMLP(nn.Module):
+    """Pax-style MLP in PyTorch."""
+
+    def __init__(self, config: TimesFmConfig):
+        super().__init__()
+        hidden_size = config.hidden_size
+        intermediate_size = config.intermediate_size
+
+        self.gate_proj = nn.Linear(hidden_size, intermediate_size)
+        self.down_proj = nn.Linear(intermediate_size, hidden_size)
+        self.layer_norm = nn.LayerNorm(normalized_shape=hidden_size, eps=1e-6)
+
+    def forward(self, x, paddings=None):
+        gate_inp = self.layer_norm(x)
+        gate = self.gate_proj(gate_inp)
+        gate = F.relu(gate)
+        outputs = self.down_proj(gate)
+        if paddings is not None:
+            outputs = outputs * (1.0 - paddings[:, :, None])
+        return outputs + x
+
+
+class TimesFmResidualBlock(nn.Module):
+    """TimesFM residual block."""
+
+    def __init__(self, input_dims, hidden_dims, output_dims):
+        super().__init__()
+        self.input_dims = input_dims
+        self.hidden_dims = hidden_dims
+        self.output_dims = output_dims
+
+        self.input_layer = nn.Linear(input_dims, hidden_dims)
+        self.activation = nn.SiLU()
+        self.output_layer = nn.Linear(hidden_dims, output_dims)
+        self.residual_layer = nn.Linear(input_dims, output_dims)
+
+    def forward(self, x):
+        hidden = self.input_layer(x)
+        hidden = self.activation(hidden)
+        output = self.output_layer(hidden)
+        residual = self.residual_layer(x)
+        return output + residual
+
+
+class TimesFmRMSNorm(LlamaRMSNorm):
+    pass
+
+
+class TimesFmPositionalEmbedding(nn.Module):
+    """Generates position embedding for a given 1-d sequence."""
+
+    def __init__(self, config: TimesFmConfig):
+        super().__init__()
+        min_timescale = config.min_timescale
+        max_timescale = config.max_timescale
+        self.embedding_dims = config.hidden_size
+
+        num_timescales = self.embedding_dims // 2
+        log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / max(num_timescales - 1, 1)
+        self.register_buffer(
+            "inv_timescales",
+            min_timescale * torch.exp(torch.arange(num_timescales, dtype=torch.float32) * -log_timescale_increment),
+        )
+
+    def forward(self, seq_length=None, position=None):
+        """Generates a Tensor of sinusoids with different frequencies.
+
+        Args:
+            seq_length: an optional Python int defining the output sequence length.
+                Required when the `position` argument is not specified.
+            position: [B, seq_length], optional position for each token in the
+                sequence, only required when the sequence is packed.
+
+        Returns:
+            [B, seqlen, D] if `position` is specified, else [1, seqlen, D]
+        """
+        if position is None and seq_length is None:
+            raise ValueError("Either position or seq_length must be provided")
+
+        if position is None:
+            # [1, seqlen]
+            position = torch.arange(seq_length, dtype=torch.float32, device=self.inv_timescales.device).unsqueeze(0)
+        elif position.ndim != 2:
+            raise ValueError(f"position must be 2-dimensional, got shape {position.shape}")
+
+        scaled_time = position.view(*position.shape, 1) * self.inv_timescales.view(1, 1, -1)
+        signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)
+
+        # Padding to ensure correct embedding dimension
+        signal = F.pad(signal, (0, 0, 0, self.embedding_dims % 2))
+        return signal
+
+
+class TimesFmAttention(nn.Module):
+    """Implements the attention used in TimesFM. One key difference is the per-dim scaling of the query."""
+
+    def __init__(self, config: TimesFmConfig, layer_idx: int):
+        super().__init__()
+        self.config = config
+        self.is_causal = True
+        self.attention_dropout = config.attention_dropout
+        self.layer_idx = layer_idx
+
+        self.num_heads = config.num_attention_heads
+        self.hidden_size = config.hidden_size
+        self.head_dim = config.head_dim
+
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_heads * self.head_dim
+        self.scaling = nn.Parameter(torch.empty((self.head_dim,)))
+
+        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim)
+        self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim)
+        self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim)
+        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size)
+
+    def _scale_query(self, query: torch.Tensor) -> torch.Tensor:
+        # 1.442695041 = 1 / softplus(0), following the per-dim scale of the original Pax implementation
+        scale = F.softplus(self.scaling).mul(1.442695041 / math.sqrt(self.head_dim))
+        return query * scale[None, None, None, :]
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        input_shape = hidden_states.shape[:-1]
+        hidden_shape = (*input_shape, -1, self.head_dim)
+
+        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+        query_states = self._scale_query(query_states)
+        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+        attention_interface: Callable = simple_eager_attention_forward
+        if self.config._attn_implementation != "eager":
+            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
+                logger.warning_once(
+                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
+                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+                )
+            else:
+                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=0.0 if not self.training else self.attention_dropout,
+            scaling=1.0,
+            **kwargs,
+        )
+        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
+        return attn_output, attn_weights
+
+
+class TimesFmDecoderLayer(nn.Module):
+    """Transformer layer."""
+
+    def __init__(self, config: TimesFmConfig, layer_idx: int):
+        super().__init__()
+
+        self.self_attn = TimesFmAttention(config, layer_idx=layer_idx)
+        self.mlp = TimesFmMLP(config)
+        self.input_layernorm = TimesFmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        paddings: torch.Tensor,
+        output_attentions: bool = False,
+    ) -> tuple[Optional[torch.Tensor], torch.Tensor]:
+        # Self Attention
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+        hidden_states, scores = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            output_attentions=output_attentions,
+        )
+        hidden_states = residual + hidden_states
+
+        # MLP
+        hidden_states = self.mlp(hidden_states, paddings=paddings)
+
+        return scores, hidden_states
+
+
+TIMESFM_START_DOCSTRING = r"""
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
+    and behavior.
+
+    Parameters:
+        config ([`TimesFmConfig`]):
+            Model configuration class with all the parameters of the model. Initializing with a config file does not
+            load the weights associated with the model, only the configuration. Check out the
+            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+    "The bare TimesFM Model outputting raw hidden-states without any specific head on top.",
+    TIMESFM_START_DOCSTRING,
+)
+class TimesFmPreTrainedModel(PreTrainedModel):
+    """Handles weight initialization and pretrained-checkpoint loading for all TimesFM models."""
+
+    config_class = TimesFmConfig
+    base_model_prefix = "timesfm"
+    _no_split_modules = ["TimesFmDecoderLayer"]
+    main_input_name = "past_values"
+    _supports_sdpa = True
+
+    def _init_weights(self, module):
+        if isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0, std=self.config.initializer_range)
+
+        elif isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0, std=self.config.initializer_range)
+            if module.bias is not None:
+                nn.init.zeros_(module.bias)
+
+        elif isinstance(module, nn.LayerNorm):
+            nn.init.ones_(module.weight)
+            nn.init.zeros_(module.bias)
+
+        elif isinstance(module, TimesFmRMSNorm):
+            nn.init.zeros_(module.weight)
+
+        elif isinstance(module, TimesFmAttention):
+            # Initialize scaling parameter
+            nn.init.ones_(module.scaling)
+
+
+TIMESFM_INPUTS_DOCSTRING = r"""
+    Args:
+        past_values: list of time series forecast contexts. Each context time series
+            can be a torch Tensor of potentially different context lengths.
+        freq: frequency of each context time series in the inputs. 0 for high frequency
+            (default), 1 for medium, and 2 for low.
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+"""
+
+
+@add_start_docstrings(
+    "The bare TimesFM Model outputting raw hidden-states without any specific head on top.",
+    TIMESFM_START_DOCSTRING,
+)
+class TimesFmModel(TimesFmPreTrainedModel):
+    """Patched time-series decoder without any specific output layer."""
+
+    def __init__(self, config: TimesFmConfig):
+        super().__init__(config)
+
+        self.config = config
+        self.input_ff_layer = TimesFmResidualBlock(
+            input_dims=2 * config.patch_length,
+            output_dims=config.hidden_size,
+            hidden_dims=config.intermediate_size,
+        )
+        self.freq_emb = nn.Embedding(num_embeddings=config.freq_size, embedding_dim=config.hidden_size)
+        self.layers = nn.ModuleList(
+            [TimesFmDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+        )
+        if self.config.use_positional_embedding:
+            self.position_emb = TimesFmPositionalEmbedding(config=config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def _forward_transform(
+        self, inputs: torch.Tensor, patched_pads: torch.Tensor
+    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        """Input is of shape [B, N, P]."""
+        mu, sigma = self._timesfm_masked_mean_std(inputs, patched_pads)
+        sigma = torch.where(
+            sigma < self.config.tolerance,
+            torch.tensor(1.0, dtype=sigma.dtype, device=sigma.device),
+            sigma,
+        )
+
+        # Normalize each patch
+        outputs = (inputs - mu[:, None, None]) / sigma[:, None, None]
+        outputs = torch.where(
+            torch.abs(inputs - self.config.pad_val) < self.config.tolerance,
+            torch.tensor(self.config.pad_val, dtype=outputs.dtype, device=outputs.device),
+            outputs,
+        )
+        return outputs, (mu, sigma)
+
+    @can_return_tuple
+    @add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING)
+    def forward(
+        self,
+        past_values: torch.Tensor,
+        past_values_padding: torch.LongTensor,
+        freq: torch.Tensor,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+    ) -> TimesFmOutput:
+        """
+        past_values_padding (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+            The padding indicator of the time series.
+        """
+        # Reshape into patches (using view for efficiency)
+        bsize = past_values.shape[0]
+        patched_inputs = past_values.view(bsize, -1, self.config.patch_length)
+        patched_pads = past_values_padding.view(bsize, -1, self.config.patch_length)
+
+        patched_inputs = torch.where(
+            torch.abs(patched_pads - 1.0) < self.config.tolerance,
+            torch.tensor(0.0, dtype=patched_inputs.dtype, device=patched_inputs.device),
+            patched_inputs,
+        )
+        patched_pads = torch.where(
+            torch.abs(patched_inputs - self.config.pad_val) < self.config.tolerance,
+            torch.tensor(1.0, dtype=patched_pads.dtype, device=patched_pads.device),
+            patched_pads,
+        )
+        patched_inputs, stats = self._forward_transform(patched_inputs, patched_pads)
+
+        # B x N x D
+        patched_inputs = patched_inputs * (1.0 - patched_pads)
+        concat_inputs = torch.cat([patched_inputs, patched_pads], dim=-1)
+        model_input = self.input_ff_layer(concat_inputs)
+
+        # A patch counts as padded only when all of its values are padded,
+        # hence the minimum over the patch dimension.
+        patched_padding = torch.min(patched_pads, dim=-1)[0]  # Get the values from the min result
+        if self.config.use_positional_embedding:
+            pos_emb = self.position_emb(model_input.shape[1])
+            pos_emb = torch.concat([pos_emb] * model_input.shape[0], dim=0)
+            pos_emb = self._timesfm_shift_padded_seq(patched_padding, pos_emb)
+            model_input += pos_emb
+
+        f_emb = self.freq_emb(freq)  # B x 1 x D
+        model_input += f_emb
+
+        # Convert paddings to attention mask and combine with causal mask
+        hidden_states = model_input
+        attention_mask = self._prepare_4d_attention_mask(
+            attention_mask=patched_padding,
+            sequence_length=hidden_states.shape[1],
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
+            is_causal=True,
+        )
+
+        all_attentions = []
+        all_hidden_states = []
+
+        for layer in self.layers[: self.config.num_hidden_layers]:
+            scores, hidden_states = layer(
+                hidden_states=hidden_states,
+                attention_mask=attention_mask,
+                paddings=patched_padding,
+                output_attentions=output_attentions,
+            )
+            if output_attentions:
+                all_attentions.append(scores)
+            if output_hidden_states:
+                all_hidden_states.append(hidden_states)
+
+        if output_hidden_states:
+            all_hidden_states = [model_input] + all_hidden_states
+        else:
+            all_hidden_states = None
+
+        return TimesFmOutput(
+            last_hidden_state=hidden_states,
+            hidden_states=all_hidden_states,
+            attentions=all_attentions if output_attentions else None,
+            loc=stats[0],
+            scale=stats[1],
+        )
+
+    @staticmethod
+    def _prepare_4d_attention_mask(
+        attention_mask: Optional[torch.Tensor],
+        sequence_length: int,
+        dtype: torch.dtype,
+        device: torch.device,
+        is_causal: bool = True,
+    ) -> Optional[torch.Tensor]:
+        """
+        Creates 4D attention mask and combines causal and padding masks if needed.
+
+        Args:
+            attention_mask: Optional tensor of shape (batch_size, seq_length) containing padding mask
+            sequence_length: Length of the sequence
+            dtype: Data type of the mask
+            device: Device of the mask
+            is_causal: Whether to apply causal masking
+
+        Returns:
+            4D attention mask of shape (batch_size, 1, seq_length, seq_length)
+        """
+        # Get minimum value for the dtype
+        min_value = torch.finfo(dtype).min if dtype.is_floating_point else torch.iinfo(dtype).min
+
+        # Handle padding mask
+        if attention_mask is not None:
+            # Convert 2D padding mask to 4D attention mask
+            attention_mask = attention_mask.view(attention_mask.shape[0], 1, 1, -1)
+            attention_mask = attention_mask * min_value
+
+        # Create causal mask if needed
+        if is_causal:
+            causal_mask = torch.triu(
+                torch.ones((sequence_length, sequence_length), dtype=dtype, device=device) * min_value,
+                diagonal=1,
+            )
+            causal_mask = causal_mask.view(1, 1, sequence_length, sequence_length)
+
+            # Combine with padding mask if it exists
+            if attention_mask is not None:
+                attention_mask = torch.minimum(attention_mask, causal_mask)
+            else:
+                attention_mask = causal_mask
+
+        return attention_mask
+
+    @staticmethod
+    def _timesfm_masked_mean_std(inputs: torch.Tensor, padding: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        """Calculates mean and standard deviation of `inputs` across axis 1.
+
+        It excludes values where `padding` is 1.
+
+        Args:
+            inputs: A PyTorch tensor of shape [b, n, p].
+            padding: A PyTorch tensor of shape [b, n, p] with values 0 or 1.
+
+        Returns:
+            A tuple containing the mean and standard deviation.
+            We return the statistics of the first patch with at least three non-padded values.
+        """
+
+        # Selecting the first patch with at least 3 unpadded values.
+        def _get_patch_index(arr: torch.Tensor):
+            indices = torch.argmax((arr >= 3).to(torch.int32), dim=1)
+            row_sum = (arr >= 3).to(torch.int32).sum(dim=1)
+            return torch.where(row_sum == 0, arr.shape[1] - 1, indices)
+
+        pad_sum = torch.sum(1 - padding, dim=2)
+        patch_indices = _get_patch_index(pad_sum)
+        bidxs = torch.arange(inputs.shape[0])
+
+        arr = inputs[bidxs, patch_indices, :]
+        pad = padding[bidxs, patch_indices, :]
+
+        # Create a mask where padding is 0
+        mask = 1 - pad
+
+        # Calculate the number of valid elements
+        num_valid_elements = torch.sum(mask, dim=1)
+        num_valid_elements = torch.where(
+            num_valid_elements == 0,
+            torch.tensor(1, dtype=num_valid_elements.dtype, device=num_valid_elements.device),
+            num_valid_elements,
+        )
+
+        # Calculate the masked sum and squared sum
+        masked_sum = torch.sum(arr * mask, dim=1)
+        masked_squared_sum = torch.sum((arr * mask) ** 2, dim=1)
+
+        # Calculate the masked mean and standard deviation
+        masked_mean = masked_sum / num_valid_elements
+        masked_var = masked_squared_sum / num_valid_elements - masked_mean**2
+        masked_var = torch.where(
+            masked_var < 0.0,
+            torch.tensor(0.0, dtype=masked_var.dtype, device=masked_var.device),
+            masked_var,
+        )
+        masked_std = torch.sqrt(masked_var)
+
+        return masked_mean, masked_std
+
+    @staticmethod
+    def _timesfm_shift_padded_seq(mask: torch.Tensor, seq: torch.Tensor) -> torch.Tensor:
+        """Shifts rows of seq based on the first 0 in each row of the mask.
+
+        Args:
+            mask: mask tensor of shape [B, N]
+            seq: seq tensor of shape [B, N, P]
+
+        Returns:
+            The shifted sequence.
+        """
+        batch_size, num_seq, feature_dim = seq.shape
+
+        new_mask: torch.BoolTensor = mask == 0
+
+        # Use argmax to find the first True value in each row
+        indices = new_mask.to(torch.int32).argmax(dim=1)
+
+        # Rows whose mask contains no zeros get index -1 (nothing to shift)
+        indices[~new_mask.any(dim=1)] = -1
+
+        # Create index ranges for each sequence in the batch
+        idx_range = torch.arange(num_seq, device=seq.device).view(1, -1, 1).expand(batch_size, -1, feature_dim)
+
+        # Calculate shifted indices for each element in each sequence
+        shifted_idx = (idx_range - indices[:, None, None]) % num_seq
+
+        # Gather values from seq using shifted indices
+        shifted_seq = seq.gather(1, shifted_idx)
+
+        return shifted_seq
+
+
+class TimesFmModelForPrediction(TimesFmPreTrainedModel):
+    """TimesFM model for quantile and mean prediction."""
+
+    def __init__(self, config: TimesFmConfig):
+        super().__init__(config)
+
+        self.config = config
+        self.context_len = config.context_length
+        self.horizon_len = config.horizon_length
+
+        self.decoder = TimesFmModel(config)
+
+        # quantile and mean output
+        self.horizon_ff_layer = TimesFmResidualBlock(
+            input_dims=config.hidden_size,
+            output_dims=config.horizon_length * (1 + len(config.quantiles)),
+            hidden_dims=config.intermediate_size,
+        )
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def _preprocess(
+        self, inputs: Sequence[torch.Tensor], freq: Sequence[int]
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Formats and pads raw inputs to feed into the model.
+
+        This function pads or truncates each time series to match the model's
+        context length and stacks the per-series tensors into batched inputs.
+
+        Args:
+            inputs: A list of 1d Tensors. Each Tensor is the context time series of
+                a single forecast task.
+            freq: list of frequencies, one per input time series.
+
+        Returns:
+            A tuple of:
+            - the padded input time series meeting the model's required context length.
+            - the padding indicator.
+            - the frequency of each input time series, as an int32 tensor of shape
+                `(batch_size, 1)`.
+        """
+        input_ts, input_padding, inp_freq = [], [], []
+
+        for i, ts in enumerate(inputs):
+            input_len = ts.shape[0]
+            padding = torch.zeros(input_len + self.horizon_len, dtype=ts.dtype, device=ts.device)
+            if input_len < self.context_len:
+                num_front_pad = self.context_len - input_len
+                ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0)
+                padding = torch.cat([torch.ones(num_front_pad, dtype=ts.dtype, device=padding.device), padding], dim=0)
+            elif input_len > self.context_len:
+                ts = ts[-self.context_len :]
+                padding = padding[-(self.context_len + self.horizon_len) :]
+
+            input_ts.append(ts)
+            input_padding.append(padding)
+            inp_freq.append(freq[i])
+
+        return (
+            torch.stack(input_ts, dim=0),
+            torch.stack(input_padding, dim=0),
+            torch.tensor(inp_freq, dtype=torch.int32).reshape(-1, 1),
+        )
+
+    def _postprocess_output(
+        self, model_output: torch.Tensor, stats: tuple[torch.Tensor, torch.Tensor]
+    ) -> torch.Tensor:
+        """Postprocess output of stacked transformer."""
+
+        # B x N x (H.Q)
+        output_ts = self.horizon_ff_layer(model_output)
+
+        # Reshape using view
+        b, n, _ = output_ts.shape
+        output_ts = output_ts.view(b, n, self.config.horizon_length, len(self.config.quantiles) + 1)
+
+        mu, sigma = stats
+        return output_ts * sigma[:, None, None, None] + mu[:, None, None, None]
+
+    def _quantile_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+        losses = []
+        for i, q in enumerate(self.config.quantiles):
+            errors = targets - predictions[..., i]
+            loss = torch.max((q - 1) * errors, q * errors)
+            losses.append(loss.mean())
+        return torch.stack(losses).mean()
+
+    @can_return_tuple
+    @add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=TimesFmOutputForPrediction, config_class=_CONFIG_FOR_DOC)
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=TimesFmOutputForPrediction,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def forward(
+        self,
+        past_values: Sequence[torch.Tensor],
+        freq: Optional[Sequence[Union[torch.Tensor, int]]] = None,
+        window_size: Optional[int] = None,
+        future_values: Optional[torch.Tensor] = None,
+        forecast_context_len: Optional[int] = None,
+        return_forecast_on_context: bool = False,
+        truncate_negative: bool = False,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+    ) -> TimesFmOutputForPrediction:
+        r"""
+        window_size (`int`, *optional*):
+            Window size of trend + residual decomposition. If None, no decomposition is done.
+        future_values (`torch.Tensor`, *optional*):
+            Optional future time series values to be used for loss computation.
+        forecast_context_len (`int`, *optional*):
+            Optional max context length.
+        return_forecast_on_context (`bool`, *optional*):
+            True to return the forecast on the context when available, i.e. after the first input patch.
+        truncate_negative (`bool`, *optional*):
+            Truncate to only non-negative values if all of the context values are non-negative,
+            otherwise do nothing.
+        output_attentions (`bool`, *optional*):
+            Whether to output the attentions.
+        output_hidden_states (`bool`, *optional*):
+            Whether to output the hidden states.
+
+        Returns:
+            A TimesFmOutputForPrediction object or a tuple containing:
+            - the mean forecast of shape `(num_inputs, horizon_length)`,
+            - the full forecast (mean + quantiles) of shape
+                `(num_inputs, horizon_length, 1 + len(quantiles))`.
+            - loss: the mean squared error loss + quantile loss if `future_values` is provided.
+        """
+        if forecast_context_len is None:
+            fcontext_len = self.context_len
+        else:
+            fcontext_len = forecast_context_len
+
+        # Get device from first input tensor
+        device = past_values[0].device
+
+        # Truncate inputs to forecast_context_len
+        inputs = [ts[-fcontext_len:] for ts in past_values]
+        inp_min = torch.min(torch.stack([torch.min(ts) for ts in inputs]))
+
+        if window_size is not None:
+            new_inputs = []
+            new_freqs = []
+            for i, ts in enumerate(inputs):
+                new_inputs.extend(self._timesfm_moving_average(ts, window_size))
+                if freq is not None:
+                    new_freqs.extend([freq[i]] * 2)
+            inputs = new_inputs
+            if freq is not None:
+                freq = new_freqs
+
+        if freq is None:
+            logger.info("No frequency provided via `freq`. Defaulting to high frequency (0).")
+            freq = [0] * len(inputs)
+
+        if output_attentions is None:
+            output_attentions = self.config.output_attentions
+        if output_hidden_states is None:
+            output_hidden_states = self.config.output_hidden_states
+
+        input_ts, input_padding, inp_freq = self._preprocess(inputs, freq)
+        # Move tensors to the same device as input
+        input_ts = input_ts.to(device)
+        input_padding = input_padding.to(device)
+        inp_freq = inp_freq.to(device)
+
+        final_out = input_ts
+        context_len = final_out.shape[1]
+        full_outputs = []
+
+        if input_padding.shape[1] != final_out.shape[1] + self.horizon_len:
+            raise ValueError(
+                "Length of paddings must match length of input + horizon_len:"
+                f" {input_padding.shape[1]} != {final_out.shape[1]} + {self.horizon_len}"
+            )
+        output_patch_len = self.config.horizon_length
+
+        num_decode_patches = (self.horizon_len + output_patch_len - 1) // output_patch_len
+        for step_index in range(num_decode_patches):
+            current_padding = input_padding[:, 0 : final_out.shape[1]]
+            input_ts = final_out[:, -fcontext_len:]
+            input_padding = current_padding[:, -fcontext_len:]
+            decoder_output = self.decoder(
+                past_values=input_ts,
+                past_values_padding=input_padding,
+                freq=inp_freq,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+            )
+            fprop_outputs = self._postprocess_output(
+                decoder_output.last_hidden_state,
+                (decoder_output.loc, decoder_output.scale),
+            )
+
+            if return_forecast_on_context and step_index == 0:
+                # For the first decoding step, collect the model forecast on the
+                # context, except the unavailable forecast for the first input patch.
+                new_full_ts = fprop_outputs[:, :-1, : self.config.patch_length, :]
+                # We have to use reshape and not view for non-contiguous memory
+                new_full_ts = new_full_ts.reshape(new_full_ts.size(0), -1, new_full_ts.size(3))
+
+                full_outputs.append(new_full_ts)
+
+            # (full batch, last patch, output_patch_len, index of mean forecast = 0)
+            new_ts = fprop_outputs[:, -1, :output_patch_len, 0]
+            new_full_ts = fprop_outputs[:, -1, :output_patch_len, :]
+            # (full batch, last patch, output_patch_len, all output indices)
+            full_outputs.append(new_full_ts)
+            final_out = torch.concatenate([final_out, new_ts], axis=-1)
+
+        if return_forecast_on_context:
+            # `full_outputs` indexing starts right after the first input patch.
+            full_outputs = torch.concatenate(full_outputs, axis=1)[
+                :, : (context_len - self.config.patch_length + self.horizon_len), :
+            ]
+        else:
+            # `full_outputs` indexing starts at the forecast horizon.
+            full_outputs = torch.concatenate(full_outputs, axis=1)[:, 0 : self.horizon_len, :]
+
+        mean_outputs = full_outputs[:, :, 0]
+        if window_size is not None:
+            mean_outputs = mean_outputs[0::2, ...] + mean_outputs[1::2, ...]
+            full_outputs = full_outputs[0::2, ...] + full_outputs[1::2, ...]
+        if inp_min >= 0 and truncate_negative:
+            mean_outputs = torch.maximum(mean_outputs, torch.zeros_like(mean_outputs))
+            full_outputs = torch.maximum(full_outputs, torch.zeros_like(full_outputs))
+
+        loss = None
+        if future_values is not None:
+            mse_loss = F.mse_loss(mean_outputs, future_values)
+            quantile_loss = self._quantile_loss(full_outputs[:, :, 1:], future_values)
+            loss = mse_loss + quantile_loss
+
+        return TimesFmOutputForPrediction(
+            last_hidden_state=decoder_output.last_hidden_state,
+            attentions=decoder_output.attentions if output_attentions else None,
+            hidden_states=decoder_output.hidden_states if output_hidden_states else None,
+            mean_predictions=mean_outputs,
+            full_predictions=full_outputs,
+            loss=loss,
+        )
+
+    @staticmethod
+    def _timesfm_moving_average(arr: torch.Tensor, window_size: int) -> list[torch.Tensor]:
+        """Calculates the moving average using PyTorch's convolution function."""
+        # Pad with zeros to handle initial window positions
+        arr_padded = F.pad(arr, (window_size - 1, 0), "constant", 0)
+        # Create a convolution kernel
+        kernel = torch.ones(window_size, dtype=arr.dtype, device=arr.device) / window_size
+        # Apply convolution to calculate the moving average
+        smoothed_arr = F.conv1d(arr_padded.view(1, 1, -1), kernel.view(1, 1, -1)).squeeze()
+        return [smoothed_arr, arr - smoothed_arr]
+
+
+__all__ = ["TimesFmModelForPrediction", "TimesFmPreTrainedModel", "TimesFmModel"]
diff --git a/tests/models/timesfm/__init__.py b/tests/models/timesfm/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/tests/models/timesfm/test_modeling_timesfm.py b/tests/models/timesfm/test_modeling_timesfm.py
new file mode 100644
index 00000000000..6d69a97d352
--- /dev/null
+++ b/tests/models/timesfm/test_modeling_timesfm.py
@@ -0,0 +1,197 @@
+# coding=utf-8
+# Copyright 2025 Google LLC and HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import unittest
+from typing import List
+
+import numpy as np
+import torch
+
+from transformers import TimesFmConfig, is_torch_available
+from transformers.testing_utils import require_torch, slow, torch_device
+
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin
+
+
+if is_torch_available():
+    from transformers import TimesFmModelForPrediction
+
+TOLERANCE = 1e-4
+
+
+class TimesFmModelTester:
+    def __init__(
+        self,
+        parent,
+        patch_length: int = 32,
+        context_length: int = 512,
+        horizon_length: int = 128,
+        freq_size: int = 3,
+        num_hidden_layers: int = 1,
+        hidden_size: int = 16,
+        intermediate_size: int = 32,
+        head_dim: int = 8,
+        num_heads: int = 2,
+        tolerance: float = 1e-6,
+        rms_norm_eps: float = 1e-6,
+        quantiles: List[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
+        pad_val: float = 1123581321.0,
+        use_positional_embedding: bool = True,
+        initializer_factor: float = 0.0,
+        is_training: bool = False,
+        batch_size: int = 3,
+    ):
+        self.parent = parent
+        self.patch_length = patch_length
+        self.context_length = context_length
+        self.horizon_length = horizon_length
+        self.quantiles = quantiles
+        self.pad_val = pad_val
+        self.freq_size = freq_size
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.head_dim = head_dim
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_heads
+        self.tolerance = tolerance
+        self.rms_norm_eps = rms_norm_eps
+        self.use_positional_embedding = use_positional_embedding
+        self.initializer_factor = initializer_factor
+        self.is_training = is_training
+        self.batch_size = batch_size
+
+        # The size of the test input
+        self.seq_length = context_length // patch_length
+
+    def get_config(self):
+        return TimesFmConfig(
+            patch_length=self.patch_length,
+            context_length=self.context_length,
+            horizon_length=self.horizon_length,
+            quantiles=self.quantiles,
+            pad_val=self.pad_val,
+            freq_size=self.freq_size,
+            hidden_size=self.hidden_size,
+            intermediate_size=self.intermediate_size,
+            head_dim=self.head_dim,
+            num_hidden_layers=self.num_hidden_layers,
+            num_attention_heads=self.num_attention_heads,
+            tolerance=self.tolerance,
+            rms_norm_eps=self.rms_norm_eps,
+            use_positional_embedding=self.use_positional_embedding,
+            initializer_factor=self.initializer_factor,
+        )
+
+    def get_pipeline_config(self):
+        return self.get_config()
+
+    def prepare_config_and_inputs(self):
+        forecast_input = [
+            torch.tensor(np.sin(np.linspace(0, 20, 100)), dtype=torch.float32, device=torch_device),
+            torch.tensor(np.cos(np.linspace(0, 20, 100)), dtype=torch.float32, device=torch_device),
+            torch.tensor(np.tan(np.linspace(0, 20, 100)), dtype=torch.float32, device=torch_device),
+        ]
+        frequency_input = torch.tensor([0, 1, 2], dtype=torch.long, device=torch_device)
+
+        return (self.get_config(), torch.stack(forecast_input, dim=0), frequency_input)
+
+    def prepare_config_and_inputs_for_common(self):
+        (config, forecast_input, frequency_input) = self.prepare_config_and_inputs()
+
+        inputs_dict = {
+            "past_values": forecast_input,
+            "freq": frequency_input,
+        }
+        return config, inputs_dict
+
+
+@require_torch
+class TimesFmModelTest(ModelTesterMixin, unittest.TestCase):
+    all_model_classes = (TimesFmModelForPrediction,) if is_torch_available() else ()
+    all_generative_model_classes = ()
+    all_parallelizable_model_classes = ()
+    fx_compatible = False
+    test_pruning = False
+    test_resize_embeddings = False
+    test_model_parallel = False
+    is_encoder_decoder = False
+    test_inputs_embeds = False
+
+    def setUp(self):
+        self.model_tester = TimesFmModelTester(self)
+        self.config_tester = ConfigTester(self, config_class=TimesFmConfig)
+
+    def test_create_and_run_model(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        model = TimesFmModelForPrediction(config)
+        model.to(torch_device)
+        model.eval()
+        results = model(**inputs_dict)
+        assert results.mean_predictions is not None
+
+    @unittest.skip(reason="Compile not yet supported because of masks")
+    def test_sdpa_can_dispatch_on_flash(self):
+        pass
+
+    @unittest.skip(reason="Model does not have input embeddings")
+    def test_model_get_set_embeddings(self):
+        pass
+
+    @unittest.skip(reason="Model does not have head mask")
+    def test_headmasking(self):
+        pass
+
+    # the main input name is `past_values`
+    def test_model_main_input_name(self):
+        model_signature = inspect.signature(getattr(TimesFmModelForPrediction, "forward"))
+        # The main input is the name of the argument after `self`
+        observed_main_input_name = list(model_signature.parameters.keys())[1]
+        self.assertEqual(TimesFmModelForPrediction.main_input_name, observed_main_input_name)
+
+
+@require_torch
+@slow
+class TimesFmModelIntegrationTests(unittest.TestCase):
+    def test_inference_no_head(self):
+        model = TimesFmModelForPrediction.from_pretrained("google/timesfm-2.0-500m-pytorch", revision="refs/pr/7").to(
+            torch_device
+        )
+        forecast_input = [
+            np.sin(np.linspace(0, 20, 100)),
+            np.sin(np.linspace(0, 20, 200)),
+            np.sin(np.linspace(0, 20, 400)),
+        ]
+        forecast_input_tensor = [torch.tensor(ts, dtype=torch.float32, device=torch_device) for ts in forecast_input]
+        frequency_input = [0, 1, 2]
+
+        with torch.no_grad():
+            output = model(past_values=forecast_input_tensor, freq=frequency_input).last_hidden_state
+
+        self.assertEqual(
+            output.shape,
+            torch.Size([3, model.config.context_length // model.config.patch_length, model.config.hidden_size]),
+        )
+        expected_slice = torch.tensor(
+            [[-0.4267, -0.7273, -0.3932], [-0.4267, -0.7273, -0.3932], [-0.4267, -0.7273, -0.3932]],
+            device=torch_device,
+        )
+        self.assertTrue(torch.allclose(output[0, :3, :3], expected_slice, atol=TOLERANCE))
diff --git a/utils/check_repo.py b/utils/check_repo.py
index aeba4ee73de..85178b663e4 100644
--- a/utils/check_repo.py
+++ b/utils/check_repo.py
@@ -156,6 +156,7 @@ IGNORE_NON_TESTED = (
     "Llama4VisionModel",  # Building part of bigger (tested) model. # TODO: add tests
     "Emu3VQVAE",  # Building part of bigger (tested) model
     "Emu3TextModel",  # Building part of bigger (tested) model
+    "TimesFmModel",  # Building part of bigger (tested) model
 ]
 )
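
For reviewers who want to try the new API, here is a minimal usage sketch of the prediction head added by this patch. It mirrors `TimesFmModelIntegrationTests` above; the checkpoint name and `revision="refs/pr/7"` are taken from that test and may change once the converted weights are merged upstream, so treat this as illustrative rather than canonical.

```python
import numpy as np
import torch

from transformers import TimesFmModelForPrediction

# Checkpoint and revision as used in the integration test; they may change after merge.
model = TimesFmModelForPrediction.from_pretrained(
    "google/timesfm-2.0-500m-pytorch", revision="refs/pr/7"
)
model.eval()

# Contexts of different lengths are allowed; `_preprocess` pads or truncates
# each one to `config.context_length`.
past_values = [
    torch.tensor(np.sin(np.linspace(0, 20, 100)), dtype=torch.float32),
    torch.tensor(np.sin(np.linspace(0, 20, 400)), dtype=torch.float32),
]
freq = [0, 0]  # 0 = high, 1 = medium, 2 = low frequency

with torch.no_grad():
    outputs = model(past_values=past_values, freq=freq)

# (num_inputs, horizon_length) and (num_inputs, horizon_length, 1 + len(quantiles))
print(outputs.mean_predictions.shape)
print(outputs.full_predictions.shape)
```

Note that `forward` handles preprocessing (padding, normalization) and runs the autoregressive decode loop itself, which is why the PR overrides `generate` as unneeded.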
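The `_quantile_loss` helper in both model files is the standard pinball loss averaged over the configured quantiles. Below is a self-contained sketch of the per-quantile term; the `pinball_loss` name is illustrative, not part of the patch.

```python
import torch


def pinball_loss(predictions: torch.Tensor, targets: torch.Tensor, q: float) -> torch.Tensor:
    """Pinball (quantile) loss for a single quantile q, the per-quantile term of `_quantile_loss`."""
    errors = targets - predictions
    # Under-prediction (positive error) is weighted by q, over-prediction by (1 - q).
    return torch.max((q - 1) * errors, q * errors).mean()


targets = torch.tensor([1.0, 2.0, 3.0])
predictions = torch.tensor([1.5, 1.5, 1.5])
# A high quantile penalizes under-prediction more heavily.
print(pinball_loss(predictions, targets, q=0.9))  # tensor(0.6167)
```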
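When `window_size` is given, `forward` splits every context into a trend (moving average) and a residual, forecasts both, and sums the interleaved halves back together (the `[0::2] + [1::2]` lines). A standalone sketch of the decomposition, mirroring `_timesfm_moving_average` (the `moving_average_decompose` name is illustrative):

```python
import torch
import torch.nn.functional as F


def moving_average_decompose(arr: torch.Tensor, window_size: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Split a 1-d series into a trend and a residual component."""
    # Left-pad with zeros so the smoothed output has the same length as the input.
    arr_padded = F.pad(arr, (window_size - 1, 0), "constant", 0)
    kernel = torch.ones(window_size, dtype=arr.dtype) / window_size
    trend = F.conv1d(arr_padded.view(1, 1, -1), kernel.view(1, 1, -1)).squeeze()
    return trend, arr - trend


series = torch.arange(10, dtype=torch.float32)
trend, residual = moving_average_decompose(series, window_size=3)
# The two components sum back to the original series by construction.
assert torch.allclose(trend + residual, series)
```

Because the decomposition doubles the effective batch, the even/odd interleaving after the decode loop restores exactly one forecast per original input series.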