
# coding=utf-8
# Copyright 2025 Google LLC and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import unittest
from typing import List

import numpy as np
import torch

from transformers import TimesFmConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from transformers.utils import is_torch_fx_available

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin


if is_torch_fx_available():
    pass

if is_torch_available():
    from transformers import TimesFmModelForPrediction

TOLERANCE = 1e-4
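

# Illustrative sketch (hypothetical helper, not part of the transformers API):
# TimesFM reads its context window as a sequence of non-overlapping patches,
# which is why the tester below sets seq_length = context_length // patch_length.
# This sketch assumes the series is already padded or truncated to exactly
# context_length points; the model's own preprocessing may differ.
import torch


def _sketch_split_into_patches(series: torch.Tensor, patch_length: int) -> torch.Tensor:
    """Reshape a (batch, context_length) tensor into (batch, num_patches, patch_length)."""
    batch_size, context_length = series.shape
    if context_length % patch_length != 0:
        raise ValueError("context_length must be a multiple of patch_length")
    # Consecutive, non-overlapping patches; reshape returns a view when possible.
    return series.reshape(batch_size, context_length // patch_length, patch_length)


# With the tester defaults (context_length=512, patch_length=32) there are 16 patches:
# _sketch_split_into_patches(torch.zeros(3, 512), 32).shape == torch.Size([3, 16, 32])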


class TimesFmModelTester:
    def __init__(
        self,
        parent,
        patch_length: int = 32,
        context_length: int = 512,
        horizon_length: int = 128,
        freq_size: int = 3,
        num_hidden_layers: int = 1,
        hidden_size: int = 16,
        intermediate_size: int = 32,
        head_dim: int = 8,
        num_heads: int = 2,
        tolerance: float = 1e-6,
        rms_norm_eps: float = 1e-6,
        quantiles: List[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
        pad_val: float = 1123581321.0,
        use_positional_embedding: bool = True,
        initializer_factor: float = 0.0,
        is_training: bool = False,
        batch_size: int = 3,
    ):
        self.parent = parent
        self.patch_length = patch_length
        self.context_length = context_length
        self.horizon_length = horizon_length
        self.quantiles = quantiles
        self.pad_val = pad_val
        self.freq_size = freq_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.head_dim = head_dim
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_heads
        self.tolerance = tolerance
        self.rms_norm_eps = rms_norm_eps
        self.use_positional_embedding = use_positional_embedding
        self.initializer_factor = initializer_factor
        self.is_training = is_training
        self.batch_size = batch_size

        # Sequence length seen by the model: the number of patches in the context window
        self.seq_length = context_length // patch_length

    def get_config(self):
        return TimesFmConfig(
            patch_length=self.patch_length,
            context_length=self.context_length,
            horizon_length=self.horizon_length,
            quantiles=self.quantiles,
            pad_val=self.pad_val,
            freq_size=self.freq_size,
            hidden_size=self.hidden_size,
            intermediate_size=self.intermediate_size,
            head_dim=self.head_dim,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            tolerance=self.tolerance,
            rms_norm_eps=self.rms_norm_eps,
            use_positional_embedding=self.use_positional_embedding,
            initializer_factor=self.initializer_factor,
        )

    def get_pipeline_config(self):
        return self.get_config()

    def prepare_config_and_inputs(self):
        forecast_input = [
            torch.tensor(np.sin(np.linspace(0, 20, 100)), dtype=torch.float32, device=torch_device),
            torch.tensor(np.cos(np.linspace(0, 20, 100)), dtype=torch.float32, device=torch_device),
            torch.tensor(np.tan(np.linspace(0, 20, 100)), dtype=torch.float32, device=torch_device),
        ]
        frequency_input = torch.tensor([0, 1, 2], dtype=torch.long, device=torch_device)

        return (self.get_config(), torch.stack(forecast_input, dim=0), frequency_input)

    def prepare_config_and_inputs_for_common(self):
        (config, forecast_input, frequency_input) = self.prepare_config_and_inputs()

        inputs_dict = {
            "past_values": forecast_input,
            "freq": frequency_input,
        }
        return config, inputs_dict
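

# Illustrative sketch (hypothetical helper): the `quantiles` list passed to
# TimesFmConfig above requests one forecast per quantile level. The textbook
# objective for such outputs is the pinball (quantile) loss sketched here; it is
# not necessarily the exact loss computed inside TimesFmModelForPrediction.
from typing import Sequence

import torch


def _sketch_pinball_loss(predictions: torch.Tensor, targets: torch.Tensor, quantiles: Sequence[float]) -> torch.Tensor:
    """predictions: (..., num_quantiles); targets: (...). Returns the mean loss."""
    q = torch.as_tensor(quantiles, dtype=predictions.dtype, device=predictions.device)
    # error > 0 means under-prediction (weighted by q);
    # error < 0 means over-prediction (weighted by 1 - q).
    errors = targets.unsqueeze(-1) - predictions
    return torch.maximum(q * errors, (q - 1) * errors).mean()


# Example: target 1.0, predictions [0.0, 2.0] at quantiles [0.1, 0.9] -> both terms
# contribute 0.1, so the mean loss is 0.1.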


@require_torch
class TimesFmModelTest(ModelTesterMixin, unittest.TestCase):
    all_model_classes = (TimesFmModelForPrediction,) if is_torch_available() else ()
    all_generative_model_classes = ()
    all_parallelizable_model_classes = ()
    fx_compatible = False
    test_pruning = False
    test_resize_embeddings = False
    test_model_parallel = False
    is_encoder_decoder = False
    test_inputs_embeds = False

    def setUp(self):
        self.model_tester = TimesFmModelTester(self)
        self.config_tester = ConfigTester(self, config_class=TimesFmConfig)

    def test_create_and_run_model(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        model = TimesFmModelForPrediction(config)
        model.to(torch_device)
        model.eval()
        with torch.no_grad():
            results = model(**inputs_dict)
        self.assertIsNotNone(results.mean_predictions)

    @unittest.skip(reason="Compile not yet supported because of masks")
    def test_sdpa_can_dispatch_on_flash(self):
        pass

    @unittest.skip(reason="Model does not have input embeddings")
    def test_model_get_set_embeddings(self):
        pass

    @unittest.skip(reason="Model does not have head mask")
    def test_headmasking(self):
        pass

    # The main input name is `past_values`
    def test_model_main_input_name(self):
        model_signature = inspect.signature(TimesFmModelForPrediction.forward)
        # The main input is the name of the argument after `self`
        observed_main_input_name = list(model_signature.parameters.keys())[1]
        self.assertEqual(TimesFmModelForPrediction.main_input_name, observed_main_input_name)
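

# Illustrative sketch: test_model_main_input_name above relies on
# inspect.signature listing parameters in declaration order. When `forward` is
# looked up on the class (an unbound function), `self` is parameter 0, so index 1
# is the first real input. `_SketchModel` is a hypothetical stand-in whose
# signature is only assumed to mirror the model's.
import inspect


class _SketchModel:
    main_input_name = "past_values"

    def forward(self, past_values, freq=None):
        return past_values


def _sketch_first_forward_arg(cls) -> str:
    params = list(inspect.signature(cls.forward).parameters)
    return params[1]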


@require_torch
@slow
class TimesFmModelIntegrationTests(unittest.TestCase):
    def test_inference_no_head(self):
        model = TimesFmModelForPrediction.from_pretrained("google/timesfm-2.0-500m-pytorch", revision="refs/pr/7").to(
            torch_device
        )
        forecast_input = [
            np.sin(np.linspace(0, 20, 100)),
            np.sin(np.linspace(0, 20, 200)),
            np.sin(np.linspace(0, 20, 400)),
        ]
        forecast_input_tensor = [torch.tensor(ts, dtype=torch.float32, device=torch_device) for ts in forecast_input]
        frequency_input = [0, 1, 2]

        with torch.no_grad():
            output = model(past_values=forecast_input_tensor, freq=frequency_input).last_hidden_state

        self.assertEqual(
            output.shape,
            torch.Size([3, model.config.context_length // model.config.patch_length, model.config.hidden_size]),
        )
        expected_slice = torch.tensor(
            [[-0.4267, -0.7273, -0.3932], [-0.4267, -0.7273, -0.3932], [-0.4267, -0.7273, -0.3932]],
            device=torch_device,
        )
        self.assertTrue(torch.allclose(output[0, :3, :3], expected_slice, atol=TOLERANCE))
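

# Illustrative sketch (hypothetical helper): the integration test above passes
# series of different lengths (100, 200 and 400 points) as a plain list. One
# common way to batch such inputs is to left-pad each series to a shared length,
# so the most recent observations stay aligned at the end, and to return a mask
# marking the padding. The pad value (0.0) and the mask convention (1 = padded)
# are assumptions for illustration; the model's internal preprocessing may differ.
from typing import List, Tuple

import torch


def _sketch_left_pad(series_list: List[torch.Tensor], length: int) -> Tuple[torch.Tensor, torch.Tensor]:
    batch = torch.zeros(len(series_list), length)
    mask = torch.ones(len(series_list), length)
    for i, series in enumerate(series_list):
        # Keep only the most recent `length` points if the series is longer.
        tail = series[-length:]
        start = length - tail.shape[0]
        batch[i, start:] = tail
        mask[i, start:] = 0.0
    return batch, mask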