transformers/tests/models/timesfm/test_modeling_timesfm.py
Matt 508a704055
No more Tuple, List, Dict (#38797)
* No more Tuple, List, Dict

* make fixup

* More style fixes

* Docstring fixes with regex replacement

* Trigger tests

* Redo fixes after rebase

* Fix copies

* [test all]

* update

* [test all]

* update

* [test all]

* make style after rebase

* Patch the hf_argparser test

* Patch the hf_argparser test

* style fixes

* style fixes

* style fixes

* Fix docstrings in Cohere test

* [test all]

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
2025-06-17 19:37:18 +01:00

202 lines
7.6 KiB
Python

# coding=utf-8
# Copyright 2025 Google LLC and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import unittest
import numpy as np
import torch
from transformers import TimesFmConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from transformers.utils import is_torch_fx_available
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin
# Kept as an intentional no-op: these tests need nothing extra when torch.fx is present.
if is_torch_fx_available():
    pass

# The model class is imported only when PyTorch is reported as available.
if is_torch_available():
    from transformers import TimesFmModelForPrediction

# Absolute tolerance (atol) used by the integration test when comparing predicted values.
TOLERANCE = 1e-4
class TimesFmModelTester:
    """Builds a tiny `TimesFmConfig` and matching dummy inputs for `TimesFmModelTest`.

    The constructor only records hyper-parameters. `get_config` turns them into a config, and
    `prepare_config_and_inputs` / `prepare_config_and_inputs_for_common` produce a batch of
    synthetic time series plus one frequency id per series.
    """

    def __init__(
        self,
        parent,
        patch_length: int = 32,
        context_length: int = 512,
        horizon_length: int = 128,
        freq_size: int = 3,
        num_hidden_layers: int = 1,
        hidden_size: int = 16,
        intermediate_size: int = 32,
        head_dim: int = 8,
        num_heads: int = 2,
        tolerance: float = 1e-6,
        rms_norm_eps: float = 1e-6,
        quantiles: list[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
        pad_val: float = 1123581321.0,
        use_positional_embedding: bool = True,
        initializer_factor: float = 0.0,
        is_training: bool = False,
        batch_size: int = 3,
    ):
        self.parent = parent
        self.patch_length = patch_length
        self.context_length = context_length
        self.horizon_length = horizon_length
        # Copy the list: the default is a single shared object created at definition time, and
        # without a copy every tester (and every config built from it) would alias that one list.
        self.quantiles = list(quantiles)
        self.pad_val = pad_val
        self.freq_size = freq_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.head_dim = head_dim
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_heads
        self.tolerance = tolerance
        self.rms_norm_eps = rms_norm_eps
        self.use_positional_embedding = use_positional_embedding
        self.initializer_factor = initializer_factor
        self.is_training = is_training
        self.batch_size = batch_size
        # Number of whole patches that fit in the context window.
        self.seq_length = context_length // patch_length

    def get_config(self):
        """Return a `TimesFmConfig` built from the stored hyper-parameters."""
        return TimesFmConfig(
            patch_length=self.patch_length,
            context_length=self.context_length,
            horizon_length=self.horizon_length,
            quantiles=self.quantiles,
            pad_val=self.pad_val,
            freq_size=self.freq_size,
            hidden_size=self.hidden_size,
            intermediate_size=self.intermediate_size,
            head_dim=self.head_dim,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            tolerance=self.tolerance,
            rms_norm_eps=self.rms_norm_eps,
            use_positional_embedding=self.use_positional_embedding,
            initializer_factor=self.initializer_factor,
        )

    def get_pipeline_config(self):
        """Return the same config as `get_config`."""
        return self.get_config()

    def prepare_config_and_inputs(self):
        """Return `(config, past_values, freq)` for a synthetic batch of `batch_size` series.

        `past_values` is a float32 tensor of shape `(batch_size, 100)`. Its rows cycle through sin,
        cos and tan of `np.linspace(0, 20, 100)`. `freq` is a long tensor of shape `(batch_size,)`
        whose entries cycle through `0 .. freq_size - 1`. With the defaults (`batch_size=3`,
        `freq_size=3`) this is exactly sin/cos/tan with frequency ids `[0, 1, 2]`.
        """
        timeline = np.linspace(0, 20, 100)
        waveforms = (np.sin, np.cos, np.tan)
        series = [
            torch.tensor(waveforms[row % len(waveforms)](timeline), dtype=torch.float32, device=torch_device)
            for row in range(self.batch_size)
        ]
        forecast_input = torch.stack(series, dim=0)
        frequency_input = torch.tensor(
            [row % self.freq_size for row in range(self.batch_size)], dtype=torch.long, device=torch_device
        )
        return (self.get_config(), forecast_input, frequency_input)

    def prepare_config_and_inputs_for_common(self):
        """Return `(config, inputs_dict)` with `past_values` and `freq` keys for the common tests."""
        (config, forecast_input, frequency_input) = self.prepare_config_and_inputs()
        inputs_dict = {
            "past_values": forecast_input,
            "freq": frequency_input,
        }
        return config, inputs_dict
@require_torch
class TimesFmModelTest(ModelTesterMixin, unittest.TestCase):
    """Common model tests for `TimesFmModelForPrediction`, driven by `TimesFmModelTester`."""

    all_model_classes = (TimesFmModelForPrediction,) if is_torch_available() else ()
    all_generative_model_classes = ()
    all_parallelizable_model_classes = ()
    fx_compatible = False
    test_pruning = False
    test_resize_embeddings = False
    test_model_parallel = False
    is_encoder_decoder = False
    test_inputs_embeds = False

    def setUp(self):
        self.model_tester = TimesFmModelTester(self)
        self.config_tester = ConfigTester(self, config_class=TimesFmConfig)

    def test_create_and_run_model(self):
        """A forward pass on the tiny config yields mean predictions of shape (batch, horizon)."""
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        model = TimesFmModelForPrediction(config)
        model.to(torch_device)
        model.eval()
        with torch.no_grad():
            results = model(**inputs_dict)
        # unittest assertions instead of bare `assert`, which is stripped under `python -O`.
        self.assertIsNotNone(results.mean_predictions)
        batch_size = inputs_dict["past_values"].shape[0]
        self.assertEqual(results.mean_predictions.shape, torch.Size([batch_size, config.horizon_length]))

    @unittest.skip(reason="Compile not yet supported because of masks")
    def test_sdpa_can_dispatch_on_flash(self):
        pass

    @unittest.skip(reason="Model does not have input embeddings")
    def test_model_get_set_embeddings(self):
        pass

    @unittest.skip(reason="Model does not have head mask")
    def test_headmasking(self):
        pass

    def test_model_main_input_name(self):
        """`main_input_name` must match the first argument of `forward` after `self`."""
        model_signature = inspect.signature(TimesFmModelForPrediction.forward)
        observed_main_input_name = list(model_signature.parameters.keys())[1]
        self.assertEqual(TimesFmModelForPrediction.main_input_name, observed_main_input_name)
@require_torch
@slow
class TimesFmModelIntegrationTests(unittest.TestCase):
    """Slow end-to-end check of the `google/timesfm-2.0-500m-pytorch` checkpoint."""

    def test_inference(self):
        model = TimesFmModelForPrediction.from_pretrained("google/timesfm-2.0-500m-pytorch").to(torch_device)
        # The same sine curve over [0, 20] sampled at 100, 200 and 400 points, so the three
        # series have different lengths.
        forecast_input = [
            np.sin(np.linspace(0, 20, 100)),
            np.sin(np.linspace(0, 20, 200)),
            np.sin(np.linspace(0, 20, 400)),
        ]
        forecast_input_tensor = [torch.tensor(ts, dtype=torch.float32, device=torch_device) for ts in forecast_input]
        frequency_input = [0, 1, 2]

        with torch.no_grad():
            output = model(past_values=forecast_input_tensor, freq=frequency_input)

        mean_predictions = output.mean_predictions
        self.assertEqual(mean_predictions.shape, torch.Size([3, model.config.horizon_length]))
        # fmt: off
        expected_slice = torch.tensor(
            [ 0.9813,  1.0086,  0.9985,  0.9432,  0.8505,  0.7203,  0.5596,  0.3788,
              0.1796, -0.0264, -0.2307, -0.4255, -0.5978, -0.7642, -0.8772, -0.9670,
             -1.0110, -1.0162, -0.9848, -0.9151, -0.8016, -0.6511, -0.4707, -0.2842,
             -0.0787,  0.1260,  0.3293,  0.5104,  0.6818,  0.8155,  0.9172,  0.9843,
              1.0101,  1.0025,  0.9529,  0.8588,  0.7384,  0.5885,  0.4022,  0.2099,
             -0.0035, -0.2104, -0.4146, -0.6033, -0.7661, -0.8818, -0.9725, -1.0191,
             -1.0190, -0.9874, -0.9137, -0.8069, -0.6683, -0.4939, -0.3086, -0.1106,
              0.0846,  0.2927,  0.4832,  0.6612,  0.8031,  0.9051,  0.9772,  1.0064
            ],
            device=torch_device)
        # fmt: on
        # Same |actual - expected| <= atol + rtol * |expected| criterion as the previous
        # `torch.allclose` call (rtol=1e-5 is allclose's default), but a failure reports which
        # elements mismatch and by how much instead of a bare "False is not true".
        torch.testing.assert_close(mean_predictions[0, :64], expected_slice, atol=TOLERANCE, rtol=1e-5)