added initial TimesFMModelIntegrationTests

This commit is contained in:
Kashif Rasul 2024-12-09 14:26:52 +01:00 committed by Jinan Zhou
parent 9aad1013f0
commit be8922fa97
2 changed files with 43 additions and 7 deletions

View File

@ -851,7 +851,7 @@ class TimesFMModelForPrediction(TimesFMPreTrainedModel):
- the number of padded examples for SPMD so that each core has the same
number (a multiple of `batch_size`) of examples.
"""
input_ts, input_padding = [], []
input_ts, input_padding, inp_freq = [], [], []
for i, ts in enumerate(inputs):
input_len = ts.shape[0]
@ -866,11 +866,12 @@ class TimesFMModelForPrediction(TimesFMPreTrainedModel):
input_ts.append(ts)
input_padding.append(padding)
inp_freq.append(freq[i])
return (
torch.stack(input_ts, dim=0),
torch.stack(input_padding, dim=0),
torch.tensor(freq, dtype=torch.int32, device=input_ts[0].device).reshape(-1, 1),
torch.tensor(inp_freq, dtype=torch.int32).reshape(-1, 1),
)
def _postprocess_output(
@ -990,8 +991,8 @@ class TimesFMModelForPrediction(TimesFMPreTrainedModel):
def _quantile_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
losses = []
for i, q in enumerate(self.config.quantiles):
errors = targets - predictions[:, :, i]
for q in self.config.quantiles:
errors = targets - predictions
loss = torch.max((q - 1) * errors, q * errors)
losses.append(loss.mean())
return torch.stack(losses).mean()
@ -1073,6 +1074,11 @@ class TimesFMModelForPrediction(TimesFMPreTrainedModel):
input_ts, input_padding, inp_freq = self._preprocess(inputs, freq)
# Move tensors to the same device as input
input_ts = input_ts.to(device)
input_padding = input_padding.to(device)
inp_freq = inp_freq.to(device)
mean_outputs, full_outputs, last_hidden_state, all_attentions, all_hidden_states = self.decode(
input_ts=input_ts,
paddings=input_padding,

View File

@ -20,8 +20,9 @@ from typing import List
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from transformers import TimesFMConfig, is_torch_available
from transformers.testing_utils import require_torch, torch_device
from transformers.testing_utils import require_torch, slow, torch_device
from transformers.utils import is_torch_fx_available
from ...test_configuration_common import ConfigTester
@ -32,7 +33,9 @@ if is_torch_fx_available():
pass
if is_torch_available():
from transformers import TimesFMModelForPrediction
from transformers import TimesFMDecoder, TimesFMModelForPrediction
TOLERANCE = 1e-4
class TimesFMModelTester:
@ -46,7 +49,7 @@ class TimesFMModelTester:
num_layers: int = 1,
model_dim: int = 16,
intermediate_size: int = 32,
head_dim: int = 2,
head_dim: int = 8,
num_heads: int = 2,
tolerance: float = 1e-6,
rms_norm_eps: float = 1e-6,
@ -163,3 +166,30 @@ class TimesFMModelTest(ModelTesterMixin, unittest.TestCase):
# The main input is the name of the argument after `self`
observed_main_input_name = list(model_signature.parameters.keys())[1]
self.assertEqual(TimesFMModelForPrediction.main_input_name, observed_main_input_name)
@require_torch
@slow
class TimesFMModelIntegrationTests(unittest.TestCase):
@classmethod
def load_batch(cls, filename="train-batch.pt"):
file = hf_hub_download(
repo_id="hf-internal-testing/tourism-monthly-batch", filename=filename, repo_type="dataset"
)
batch = torch.load(file, map_location=torch_device)
return batch
def test_inference_no_head(self):
model = TimesFMModelForPrediction.from_pretrained("huggingface/timesfm-tourism-monthly").to(torch_device)
batch = self.load_batch()
with torch.no_grad():
inputs = batch["past_values"]
output = model(inputs=inputs).last_hidden_state
self.assertEqual(
output.shape, torch.Size([64, model.config.context_len // model.config.patch_len, model.config.model_dim])
)
expected_slice = torch.tensor(
[[-4.0141, 3.3141, 1.9321], [-4.9121, 3.1443, 2.0836], [-5.1142, 2.7376, 2.1566]], device=torch_device
)
self.assertTrue(torch.allclose(output[0, :3, :3], expected_slice, atol=TOLERANCE))