Mirror of https://github.com/huggingface/transformers.git, synced 2025-08-03 03:31:05 +06:00
added initial TimesFMModelIntegrationTests

commit be8922fa97 (parent 9aad1013f0)
@@ -851,7 +851,7 @@ class TimesFMModelForPrediction(TimesFMPreTrainedModel):
           - the number of padded examples for SPMD so that each core has the same
             number (a multiple of `batch_size`) of examples.
         """
-        input_ts, input_padding = [], []
+        input_ts, input_padding, inp_freq = [], [], []
 
         for i, ts in enumerate(inputs):
             input_len = ts.shape[0]
@@ -866,11 +866,12 @@ class TimesFMModelForPrediction(TimesFMPreTrainedModel):
 
             input_ts.append(ts)
             input_padding.append(padding)
+            inp_freq.append(freq[i])
 
         return (
             torch.stack(input_ts, dim=0),
             torch.stack(input_padding, dim=0),
-            torch.tensor(freq, dtype=torch.int32, device=input_ts[0].device).reshape(-1, 1),
+            torch.tensor(inp_freq, dtype=torch.int32).reshape(-1, 1),
         )
 
     def _postprocess_output(
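The two hunks above change how per-series frequency codes are batched: the old code tensor-ized the raw `freq` argument and had to reach into `input_ts[0].device`, while the new code collects one entry per example into `inp_freq` inside the same loop that pads the series, so the returned frequency column lines up row-for-row with the stacked inputs (device placement is now handled later in `forward`, see the hunk below). A minimal, self-contained sketch of that pattern with made-up toy data; the real `_preprocess` also pads each series to the configured context length, and the left-padding used here is an assumption for illustration:

    import torch

    inputs = [torch.randn(32), torch.randn(48)]  # two series of different lengths
    freq = [0, 1]                                # one frequency code per series
    context_len = 64

    input_ts, input_padding, inp_freq = [], [], []
    for i, ts in enumerate(inputs):
        pad_len = context_len - ts.shape[0]
        # Pad the series with zeros and mark the padded positions with 1.0.
        input_ts.append(torch.cat([ts.new_zeros(pad_len), ts]))
        input_padding.append(torch.cat([ts.new_ones(pad_len), ts.new_zeros(ts.shape[0])]))
        inp_freq.append(freq[i])  # collected per example, so rows stay aligned

    batched_ts = torch.stack(input_ts, dim=0)                                # (2, 64)
    batched_padding = torch.stack(input_padding, dim=0)                      # (2, 64)
    batched_freq = torch.tensor(inp_freq, dtype=torch.int32).reshape(-1, 1)  # (2, 1)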
@@ -990,8 +991,8 @@ class TimesFMModelForPrediction(TimesFMPreTrainedModel):
 
     def _quantile_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
         losses = []
-        for i, q in enumerate(self.config.quantiles):
-            errors = targets - predictions[:, :, i]
+        for q in self.config.quantiles:
+            errors = targets - predictions
             loss = torch.max((q - 1) * errors, q * errors)
             losses.append(loss.mean())
         return torch.stack(losses).mean()
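This hunk simplifies the loop: instead of slicing one quantile head per index (`predictions[:, :, i]`), every quantile weight is now applied to the same prediction tensor. The loss itself is the standard pinball (quantile) loss. A minimal standalone sketch with made-up quantiles and shapes:

    import torch

    quantiles = [0.1, 0.5, 0.9]
    predictions = torch.randn(4, 8)  # hypothetical (batch, horizon) forecasts
    targets = torch.randn(4, 8)

    losses = []
    for q in quantiles:
        errors = targets - predictions
        # Under-prediction (errors > 0) is weighted by q, over-prediction by (1 - q),
        # so q = 0.9 penalizes under-prediction nine times more than over-prediction.
        loss = torch.max((q - 1) * errors, q * errors)
        losses.append(loss.mean())

    total = torch.stack(losses).mean()  # scalar loss averaged across quantiles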
@@ -1073,6 +1074,11 @@ class TimesFMModelForPrediction(TimesFMPreTrainedModel):
 
         input_ts, input_padding, inp_freq = self._preprocess(inputs, freq)
 
+        # Move tensors to the same device as input
+        input_ts = input_ts.to(device)
+        input_padding = input_padding.to(device)
+        inp_freq = inp_freq.to(device)
+
         mean_outputs, full_outputs, last_hidden_state, all_attentions, all_hidden_states = self.decode(
             input_ts=input_ts,
             paddings=input_padding,
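The added `.to(device)` calls exist because `_preprocess` now builds its tensors from plain Python lists, which typically land on CPU; moving all three once, right after preprocessing, keeps `decode` from mixing devices. A minimal sketch of the pattern in isolation (the stand-in tensors and the `device` selection are assumptions for illustration):

    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    input_ts = torch.randn(2, 64)                    # stand-in preprocessed series
    input_padding = torch.zeros(2, 64)               # stand-in padding mask
    inp_freq = torch.zeros(2, 1, dtype=torch.int32)  # stand-in frequency codes

    input_ts = input_ts.to(device)
    input_padding = input_padding.to(device)
    inp_freq = inp_freq.to(device)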
@@ -20,8 +20,9 @@ from typing import List
 import numpy as np
 import torch
 
+from huggingface_hub import hf_hub_download
 from transformers import TimesFMConfig, is_torch_available
-from transformers.testing_utils import require_torch, torch_device
+from transformers.testing_utils import require_torch, slow, torch_device
 from transformers.utils import is_torch_fx_available
 
 from ...test_configuration_common import ConfigTester
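Of the import changes, `hf_hub_download` fetches a single file from a Hub repository into the local cache and returns its filesystem path, and `slow` is the decorator gating the new integration tests. A short sketch of the download call in isolation, using the same dataset repo the test uses:

    from huggingface_hub import hf_hub_download

    # Returns the local path of the cached file; repo_type="dataset" is needed
    # because the repo is a dataset rather than a model (the default).
    path = hf_hub_download(
        repo_id="hf-internal-testing/tourism-monthly-batch",
        filename="train-batch.pt",
        repo_type="dataset",
    )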
@@ -32,7 +33,9 @@ if is_torch_fx_available():
     pass
 
 if is_torch_available():
-    from transformers import TimesFMModelForPrediction
+    from transformers import TimesFMDecoder, TimesFMModelForPrediction
+
+TOLERANCE = 1e-4
 
 
 class TimesFMModelTester:
@@ -46,7 +49,7 @@ class TimesFMModelTester:
         num_layers: int = 1,
         model_dim: int = 16,
         intermediate_size: int = 32,
-        head_dim: int = 2,
+        head_dim: int = 8,
         num_heads: int = 2,
         tolerance: float = 1e-6,
         rms_norm_eps: float = 1e-6,
@@ -163,3 +166,30 @@ class TimesFMModelTest(ModelTesterMixin, unittest.TestCase):
         # The main input is the name of the argument after `self`
         observed_main_input_name = list(model_signature.parameters.keys())[1]
         self.assertEqual(TimesFMModelForPrediction.main_input_name, observed_main_input_name)
+
+
+@require_torch
+@slow
+class TimesFMModelIntegrationTests(unittest.TestCase):
+    @classmethod
+    def load_batch(cls, filename="train-batch.pt"):
+        file = hf_hub_download(
+            repo_id="hf-internal-testing/tourism-monthly-batch", filename=filename, repo_type="dataset"
+        )
+        batch = torch.load(file, map_location=torch_device)
+        return batch
+
+    def test_inference_no_head(self):
+        model = TimesFMModelForPrediction.from_pretrained("huggingface/timesfm-tourism-monthly").to(torch_device)
+        batch = self.load_batch()
+        with torch.no_grad():
+            inputs = batch["past_values"]
+            output = model(inputs=inputs).last_hidden_state
+        self.assertEqual(
+            output.shape, torch.Size([64, model.config.context_len // model.config.patch_len, model.config.model_dim])
+        )
+
+        expected_slice = torch.tensor(
+            [[-4.0141, 3.3141, 1.9321], [-4.9121, 3.1443, 2.0836], [-5.1142, 2.7376, 2.1566]], device=torch_device
+        )
+        self.assertTrue(torch.allclose(output[0, :3, :3], expected_slice, atol=TOLERANCE))
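The new class follows the usual transformers integration-test pattern: `@slow` keeps it out of the default CI run (slow tests are enabled with the `RUN_SLOW=1` environment variable), the checkpoint and batch come from the Hub, and correctness is pinned by a shape assertion plus an `allclose` check on a small slice of the output. A minimal usage sketch mirroring what the test exercises, with a random batch standing in for the Hub dataset (the input shape is an assumption based on the test's 64-example batch):

    import torch
    from transformers import TimesFMModelForPrediction

    model = TimesFMModelForPrediction.from_pretrained("huggingface/timesfm-tourism-monthly")
    inputs = torch.randn(64, model.config.context_len)  # assumed (batch, context_len)

    with torch.no_grad():
        output = model(inputs=inputs).last_hidden_state

    # (batch, context_len // patch_len, model_dim), per the test's shape assertion
    print(output.shape)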