transformers/tests/models/timesfm/test_modeling_timesfm.py
Jinan Zhou a91020aed0
Add TimesFM Time Series Forecasting Model (#34082)
* initial documentation

* rename mask to attention_mask

* smaller tests

* fixup

* fix copies

* move to time series section

* sort docs

* isort fix

* batch_size is not a configuration

* rename to TimesFMModelForPrediction

* initial script

* add check_outputs

* remove dropout_rate

* works with torch.Tensor inputs

* rename script

* fix docstrings

* fix freq when window_size is given

* add loss

* fix _quantile_loss

* formatting

* fix isort

* add weight init

* add support for sdpa and flash_attention_2

* fixes for flash_attention

* formatting

* remove flash_attention

* fix tests

* fix file name

* fix quantile loss

* added initial TimesFMModelIntegrationTests

* fix formatting

* fix import order

* fix _quantile_loss

* add doc for SDPA

* use timesfm 2.0

* bug fix in timesfm decode function.

* compare mean forecasts

* refactor type hints, use CamelCase

* consolidate decode func

* more readable code for weight conversion

* fix-copies

* simpler init

* rename TimesFmMLP

* use T5LayerNorm

* fix tests

* use initializer_range

* TimesFmModel instead of TimesFmDecoder

* TimesFmPositionalEmbedding takes config for its init

* 2.0-500m-pytorch default configs

* use TimesFmModel

* fix formatting

* ignore TimesFmModel for testing

* fix docstring

* override generate as it's not needed

* add doc strings

* fix logging

* add docstrings to output data classes

* initial copy from t5

* added config and attention layers

* add TimesFMPositionalEmbedding

* calculate scale_factor once

* add more configs and TimesFMResidualBlock

* fix input_dims

* standardize code format with black

* remove unneeded modules

* TimesFM Model

* order of imports

* copy from Google official implementation

* remove covariate forecasting

* Adapting TimesFM to HF format

* restructuring in progress

* adapted to HF convention

* timesfm test

* the model runs

* fixing unit tests

* fixing unit tests in progress

* add post_init

* do not change TimesFMOutput

* fixing unit tests

* all unit tests passed

* remove timesfm_layers

* add intermediate_size and initialize with config

* add _CHECKPOINT_FOR_DOC

* fix comments

* Revert "fix comments"

This reverts commit 8deeb3e191.

* add _prepare_4d_attention_mask

* we do not have generative model classes

* use Cache

* return past_key_values

* modules initialized with config only

* update year

* Update docs/source/en/model_doc/timesfm.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* add layer_idx to cache

* modular timesfm

* fix test

* unwrap sequential class

* fix toctree

* remove TimesFmOnnxConfig

* fix modular

* remove TimesFmStackedDecoder

* split qkv layer into individual layers

* rename projection layers

* use ALL_ATTENTION_FUNCTIONS

* is_causal is True

* rename config

* does not support flash_attn_2

* formatting

* fix typo in docstring

* rename inputs

* add time series mapping

* Update src/transformers/models/olmo2/modeling_olmo2.py

* Update src/transformers/models/moonshine/modeling_moonshine.py

* use updated arguments

* fix class name

* add MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING

* isort

* consolidate _preprocess into forward

* fix a typo

* fix a typo

* fix toc

* fix modular

* remove asserts

* use self.config._attn_implementation

* move to _postprocess_output

* remove timesfm_get_large_negative_number

* use view instead of multiple unsqueeze

* make helpers static methods of the Model

* use to_tuple

* use to_tuple if not return_dict

* remove unused initialization block as it's incorporated in nn.Linear

* remove unused num_key_value_groups

* use the same convention as the masking method

* update modular

* do not use unsqueeze

* use view instead of unsqueeze

* use buffer for inv_timescales

* formatting

* modular conversion

* remove unneeded initialization

* add missing docstrings

* remove cache

* use simple_eager_attention_forward

* support tp_plan

* support for flex and flash attention masks

* Revert "support for flex and flash attention masks"

This reverts commit def36c4fcf.

* fix device

* fix tests on gpu

* remove unused large model test

* removed unneeded comments

* add example usage

* fix style

* add import

* Update docs/source/en/model_doc/timesfm.md

Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com>

* inherit from LlamaRMSNorm

* use can_return_tuple decorator

* remove return_dict

* fix year

* Update docs/source/en/model_doc/timesfm.md

Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com>

* pretrained does not inherit from GenerationMixin

* use model for integration test

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Rajat Sen <rsen91@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com>
Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co>
2025-04-16 15:00:53 +02:00

198 lines · 7.1 KiB · Python

# coding=utf-8
# Copyright 2025 Google LLC and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import unittest
from typing import List

import numpy as np
import torch

from transformers import TimesFmConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from transformers.utils import is_torch_fx_available

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin


if is_torch_fx_available():
    pass


if is_torch_available():
    from transformers import TimesFmModelForPrediction


TOLERANCE = 1e-4


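# Builds a tiny TimesFM configuration and deterministic sine/cosine/tangent inputs for the unit tests.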
class TimesFmModelTester:
    def __init__(
        self,
        parent,
        patch_length: int = 32,
        context_length: int = 512,
        horizon_length: int = 128,
        freq_size: int = 3,
        num_hidden_layers: int = 1,
        hidden_size: int = 16,
        intermediate_size: int = 32,
        head_dim: int = 8,
        num_heads: int = 2,
        tolerance: float = 1e-6,
        rms_norm_eps: float = 1e-6,
        quantiles: List[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
        pad_val: float = 1123581321.0,
        use_positional_embedding: bool = True,
        initializer_factor: float = 0.0,
        is_training: bool = False,
        batch_size: int = 3,
    ):
        self.parent = parent
        self.patch_length = patch_length
        self.context_length = context_length
        self.horizon_length = horizon_length
        self.quantiles = quantiles
        self.pad_val = pad_val
        self.freq_size = freq_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.head_dim = head_dim
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_heads
        self.tolerance = tolerance
        self.rms_norm_eps = rms_norm_eps
        self.use_positional_embedding = use_positional_embedding
        self.initializer_factor = initializer_factor
        self.is_training = is_training
        self.batch_size = batch_size
        # The size of test input
        self.seq_length = context_length // patch_length
        self.hidden_size = hidden_size

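    # Build a TimesFmConfig that mirrors the tester's hyperparameters.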
    def get_config(self):
        return TimesFmConfig(
            patch_length=self.patch_length,
            context_length=self.context_length,
            horizon_length=self.horizon_length,
            quantiles=self.quantiles,
            pad_val=self.pad_val,
            freq_size=self.freq_size,
            hidden_size=self.hidden_size,
            intermediate_size=self.intermediate_size,
            head_dim=self.head_dim,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            tolerance=self.tolerance,
            rms_norm_eps=self.rms_norm_eps,
            use_positional_embedding=self.use_positional_embedding,
            initializer_factor=self.initializer_factor,
        )

    def get_pipeline_config(self):
        return self.get_config()

    def prepare_config_and_inputs(self):
        forecast_input = [
            torch.tensor(np.sin(np.linspace(0, 20, 100)), dtype=torch.float32, device=torch_device),
            torch.tensor(np.cos(np.linspace(0, 20, 100)), dtype=torch.float32, device=torch_device),
            torch.tensor(np.tan(np.linspace(0, 20, 100)), dtype=torch.float32, device=torch_device),
        ]
        frequency_input = torch.tensor([0, 1, 2], dtype=torch.long, device=torch_device)
        return (self.get_config(), torch.stack(forecast_input, dim=0), frequency_input)

    def prepare_config_and_inputs_for_common(self):
        (config, forecast_input, frequency_input) = self.prepare_config_and_inputs()
        inputs_dict = {
            "past_values": forecast_input,
            "freq": frequency_input,
        }
        return config, inputs_dict


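# Standard ModelTesterMixin checks for the tiny model, plus a smoke test that a randomly
# initialized model runs end to end and returns mean predictions.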
@require_torch
class TimesFmModelTest(ModelTesterMixin, unittest.TestCase):
    all_model_classes = (TimesFmModelForPrediction,) if is_torch_available() else ()
    all_generative_model_classes = ()
    all_parallelizable_model_classes = ()
    fx_compatible = False
    test_pruning = False
    test_resize_embeddings = False
    test_model_parallel = False
    is_encoder_decoder = False
    test_inputs_embeds = False

    def setUp(self):
        self.model_tester = TimesFmModelTester(self)
        self.config_tester = ConfigTester(self, config_class=TimesFmConfig)

    def test_create_and_run_model(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        model = TimesFmModelForPrediction(config)
        model.to(torch_device)
        model.eval()
        results = model(**inputs_dict)
        assert results.mean_predictions is not None

    @unittest.skip(reason="Compile not yet supported because of masks")
    def test_sdpa_can_dispatch_on_flash(self):
        pass

    @unittest.skip(reason="Model does not have input embeddings")
    def test_model_get_set_embeddings(self):
        pass

    @unittest.skip(reason="Model does not have head mask")
    def test_headmasking(self):
        pass

    # the main input name is `inputs`
    def test_model_main_input_name(self):
        model_signature = inspect.signature(getattr(TimesFmModelForPrediction, "forward"))
        # The main input is the name of the argument after `self`
        observed_main_input_name = list(model_signature.parameters.keys())[1]
        self.assertEqual(TimesFmModelForPrediction.main_input_name, observed_main_input_name)


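# Slow integration tests against the released google/timesfm-2.0-500m-pytorch checkpoint;
# they download the weights and are only run when RUN_SLOW=1 is set.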
@require_torch
@slow
class TimesFmModelIntegrationTests(unittest.TestCase):
    def test_inference_no_head(self):
        model = TimesFmModelForPrediction.from_pretrained("google/timesfm-2.0-500m-pytorch", revision="refs/pr/7").to(
            torch_device
        )
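        # Three sine series of different lengths; the model pads them to the configured context window internally.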
        forecast_input = [
            np.sin(np.linspace(0, 20, 100)),
            np.sin(np.linspace(0, 20, 200)),
            np.sin(np.linspace(0, 20, 400)),
        ]
        forecast_input_tensor = [torch.tensor(ts, dtype=torch.float32, device=torch_device) for ts in forecast_input]
        frequency_input = [0, 1, 2]

        with torch.no_grad():
            output = model(past_values=forecast_input_tensor, freq=frequency_input).last_hidden_state

        self.assertEqual(
            output.shape,
            torch.Size([3, model.config.context_length // model.config.patch_length, model.config.hidden_size]),
        )
        expected_slice = torch.tensor(
            [[-0.4267, -0.7273, -0.3932], [-0.4267, -0.7273, -0.3932], [-0.4267, -0.7273, -0.3932]],
            device=torch_device,
        )
        self.assertTrue(torch.allclose(output[0, :3, :3], expected_slice, atol=TOLERANCE))
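

# How to run this module locally (a sketch, assuming a development checkout of transformers):
#   python -m pytest tests/models/timesfm/test_modeling_timesfm.py -v
#   RUN_SLOW=1 python -m pytest tests/models/timesfm/test_modeling_timesfm.py -v  # also runs the slow integration test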