[TimesFM] use the main revision instead of the PR revision for the integration test (#37558)

* use the main revision instead of the refs/pr/7 revision

* test prediction

* check larger time steps
Kashif Rasul 2025-04-17 12:26:03 +03:00 committed by GitHub
parent 3bc44eaaee
commit dc06e7cecd
2 changed files with 20 additions and 13 deletions


@@ -48,6 +48,7 @@ def write_model(model_path, safe_serialization=True, huggingface_repo_id="google
             num_layers=50,
             model_dims=1280,
             use_positional_embedding=False,
+            context_len=2048,
         ),
         checkpoint=timesfm.TimesFmCheckpoint(huggingface_repo_id=huggingface_repo_id),
     )
@@ -159,6 +160,7 @@ def check_outputs(model_path, huggingface_repo_id):
             input_patch_len=32,
             output_patch_len=128,
             num_layers=50,
+            context_len=2048,
             model_dims=1280,
             use_positional_embedding=False,
             point_forecast_mode="mean",
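For context, both converter hunks above pin context_len=2048 where the reference TimesFM model is built. Below is a minimal sketch of that construction with the converter's hyperparameters; the timesfm.TimesFm / timesfm.TimesFmHparams entry points and the hard-coded repo id are assumptions for illustration (the converter passes huggingface_repo_id through as an argument), not an excerpt of the script.

import timesfm

# Sketch: reference TimesFM 2.0 500M model with the converter's hyperparameters;
# context_len=2048 is the value this commit adds explicitly.
tfm = timesfm.TimesFm(
    hparams=timesfm.TimesFmHparams(
        input_patch_len=32,
        output_patch_len=128,
        num_layers=50,
        model_dims=1280,
        use_positional_embedding=False,
        context_len=2048,
        point_forecast_mode="mean",
    ),
    checkpoint=timesfm.TimesFmCheckpoint(
        huggingface_repo_id="google/timesfm-2.0-500m-pytorch"  # assumed target checkpoint
    ),
)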


@@ -171,10 +171,8 @@ class TimesFmModelTest(ModelTesterMixin, unittest.TestCase):
 @require_torch
 @slow
 class TimesFmModelIntegrationTests(unittest.TestCase):
-    def test_inference_no_head(self):
-        model = TimesFmModelForPrediction.from_pretrained("google/timesfm-2.0-500m-pytorch", revision="refs/pr/7").to(
-            torch_device
-        )
+    def test_inference(self):
+        model = TimesFmModelForPrediction.from_pretrained("google/timesfm-2.0-500m-pytorch").to(torch_device)
         forecast_input = [
             np.sin(np.linspace(0, 20, 100)),
             np.sin(np.linspace(0, 20, 200)),
@@ -184,14 +182,21 @@ class TimesFmModelIntegrationTests(unittest.TestCase):
         frequency_input = [0, 1, 2]
         with torch.no_grad():
-            output = model(past_values=forecast_input_tensor, freq=frequency_input).last_hidden_state
+            output = model(past_values=forecast_input_tensor, freq=frequency_input)
-        self.assertEqual(
-            output.shape,
-            torch.Size([3, model.config.context_length // model.config.patch_length, model.config.hidden_size]),
-        )
+        mean_predictions = output.mean_predictions
+        self.assertEqual(mean_predictions.shape, torch.Size([3, model.config.horizon_length]))
+        # fmt: off
         expected_slice = torch.tensor(
-            [[-0.4267, -0.7273, -0.3932], [-0.4267, -0.7273, -0.3932], [-0.4267, -0.7273, -0.3932]],
-            device=torch_device,
-        )
-        self.assertTrue(torch.allclose(output[0, :3, :3], expected_slice, atol=TOLERANCE))
+            [ 0.9813, 1.0086, 0.9985, 0.9432, 0.8505, 0.7203, 0.5596, 0.3788,
+              0.1796, -0.0264, -0.2307, -0.4255, -0.5978, -0.7642, -0.8772, -0.9670,
+              -1.0110, -1.0162, -0.9848, -0.9151, -0.8016, -0.6511, -0.4707, -0.2842,
+              -0.0787, 0.1260, 0.3293, 0.5104, 0.6818, 0.8155, 0.9172, 0.9843,
+              1.0101, 1.0025, 0.9529, 0.8588, 0.7384, 0.5885, 0.4022, 0.2099,
+              -0.0035, -0.2104, -0.4146, -0.6033, -0.7661, -0.8818, -0.9725, -1.0191,
+              -1.0190, -0.9874, -0.9137, -0.8069, -0.6683, -0.4939, -0.3086, -0.1106,
+              0.0846, 0.2927, 0.4832, 0.6612, 0.8031, 0.9051, 0.9772, 1.0064
+            ],
+            device=torch_device)
+        # fmt: on
+        self.assertTrue(torch.allclose(mean_predictions[0, :64], expected_slice, atol=TOLERANCE))
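For reference, a minimal usage sketch of the path the updated test now exercises: load the checkpoint from its main revision (no refs/pr/7 pin), forecast a few sine series, and read the point forecast from mean_predictions. Passing the variable-length series as a list of float tensors and the length of the third series are illustrative assumptions; the test file has the exact forecast_input_tensor construction.

import numpy as np
import torch
from transformers import TimesFmModelForPrediction

model = TimesFmModelForPrediction.from_pretrained("google/timesfm-2.0-500m-pytorch")
model.eval()

# Sine waves with different context lengths, as in the integration test.
forecast_input = [
    np.sin(np.linspace(0, 20, 100)),
    np.sin(np.linspace(0, 20, 200)),
    np.sin(np.linspace(0, 20, 400)),  # third length chosen for illustration
]
past_values = [torch.tensor(ts, dtype=torch.float32) for ts in forecast_input]

with torch.no_grad():
    output = model(past_values=past_values, freq=[0, 1, 2])

# One point forecast per series, model.config.horizon_length steps ahead.
print(output.mean_predictions.shape)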