mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
[TimesFM] use the main revison instead of revision for integration test (#37558)
* use the main revison instead of revision * test prediction * check larger time steps
This commit is contained in:
parent
3bc44eaaee
commit
dc06e7cecd
@ -48,6 +48,7 @@ def write_model(model_path, safe_serialization=True, huggingface_repo_id="google
|
||||
num_layers=50,
|
||||
model_dims=1280,
|
||||
use_positional_embedding=False,
|
||||
context_len=2048,
|
||||
),
|
||||
checkpoint=timesfm.TimesFmCheckpoint(huggingface_repo_id=huggingface_repo_id),
|
||||
)
|
||||
@ -159,6 +160,7 @@ def check_outputs(model_path, huggingface_repo_id):
|
||||
input_patch_len=32,
|
||||
output_patch_len=128,
|
||||
num_layers=50,
|
||||
context_len=2048,
|
||||
model_dims=1280,
|
||||
use_positional_embedding=False,
|
||||
point_forecast_mode="mean",
|
||||
|
@ -171,10 +171,8 @@ class TimesFmModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
@require_torch
|
||||
@slow
|
||||
class TimesFmModelIntegrationTests(unittest.TestCase):
|
||||
def test_inference_no_head(self):
|
||||
model = TimesFmModelForPrediction.from_pretrained("google/timesfm-2.0-500m-pytorch", revision="refs/pr/7").to(
|
||||
torch_device
|
||||
)
|
||||
def test_inference(self):
|
||||
model = TimesFmModelForPrediction.from_pretrained("google/timesfm-2.0-500m-pytorch").to(torch_device)
|
||||
forecast_input = [
|
||||
np.sin(np.linspace(0, 20, 100)),
|
||||
np.sin(np.linspace(0, 20, 200)),
|
||||
@ -184,14 +182,21 @@ class TimesFmModelIntegrationTests(unittest.TestCase):
|
||||
frequency_input = [0, 1, 2]
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(past_values=forecast_input_tensor, freq=frequency_input).last_hidden_state
|
||||
output = model(past_values=forecast_input_tensor, freq=frequency_input)
|
||||
|
||||
self.assertEqual(
|
||||
output.shape,
|
||||
torch.Size([3, model.config.context_length // model.config.patch_length, model.config.hidden_size]),
|
||||
)
|
||||
mean_predictions = output.mean_predictions
|
||||
self.assertEqual(mean_predictions.shape, torch.Size([3, model.config.horizon_length]))
|
||||
# fmt: off
|
||||
expected_slice = torch.tensor(
|
||||
[[-0.4267, -0.7273, -0.3932], [-0.4267, -0.7273, -0.3932], [-0.4267, -0.7273, -0.3932]],
|
||||
device=torch_device,
|
||||
)
|
||||
self.assertTrue(torch.allclose(output[0, :3, :3], expected_slice, atol=TOLERANCE))
|
||||
[ 0.9813, 1.0086, 0.9985, 0.9432, 0.8505, 0.7203, 0.5596, 0.3788,
|
||||
0.1796, -0.0264, -0.2307, -0.4255, -0.5978, -0.7642, -0.8772, -0.9670,
|
||||
-1.0110, -1.0162, -0.9848, -0.9151, -0.8016, -0.6511, -0.4707, -0.2842,
|
||||
-0.0787, 0.1260, 0.3293, 0.5104, 0.6818, 0.8155, 0.9172, 0.9843,
|
||||
1.0101, 1.0025, 0.9529, 0.8588, 0.7384, 0.5885, 0.4022, 0.2099,
|
||||
-0.0035, -0.2104, -0.4146, -0.6033, -0.7661, -0.8818, -0.9725, -1.0191,
|
||||
-1.0190, -0.9874, -0.9137, -0.8069, -0.6683, -0.4939, -0.3086, -0.1106,
|
||||
0.0846, 0.2927, 0.4832, 0.6612, 0.8031, 0.9051, 0.9772, 1.0064
|
||||
],
|
||||
device=torch_device)
|
||||
# fmt: on
|
||||
self.assertTrue(torch.allclose(mean_predictions[0, :64], expected_slice, atol=TOLERANCE))
|
||||
|
Loading…
Reference in New Issue
Block a user