mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 21:00:08 +06:00
[TimesFM] use the main revison instead of revision for integration test (#37558)
* use the main revison instead of revision * test prediction * check larger time steps
This commit is contained in:
parent
3bc44eaaee
commit
dc06e7cecd
@ -48,6 +48,7 @@ def write_model(model_path, safe_serialization=True, huggingface_repo_id="google
|
|||||||
num_layers=50,
|
num_layers=50,
|
||||||
model_dims=1280,
|
model_dims=1280,
|
||||||
use_positional_embedding=False,
|
use_positional_embedding=False,
|
||||||
|
context_len=2048,
|
||||||
),
|
),
|
||||||
checkpoint=timesfm.TimesFmCheckpoint(huggingface_repo_id=huggingface_repo_id),
|
checkpoint=timesfm.TimesFmCheckpoint(huggingface_repo_id=huggingface_repo_id),
|
||||||
)
|
)
|
||||||
@ -159,6 +160,7 @@ def check_outputs(model_path, huggingface_repo_id):
|
|||||||
input_patch_len=32,
|
input_patch_len=32,
|
||||||
output_patch_len=128,
|
output_patch_len=128,
|
||||||
num_layers=50,
|
num_layers=50,
|
||||||
|
context_len=2048,
|
||||||
model_dims=1280,
|
model_dims=1280,
|
||||||
use_positional_embedding=False,
|
use_positional_embedding=False,
|
||||||
point_forecast_mode="mean",
|
point_forecast_mode="mean",
|
||||||
|
@ -171,10 +171,8 @@ class TimesFmModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
@require_torch
|
@require_torch
|
||||||
@slow
|
@slow
|
||||||
class TimesFmModelIntegrationTests(unittest.TestCase):
|
class TimesFmModelIntegrationTests(unittest.TestCase):
|
||||||
def test_inference_no_head(self):
|
def test_inference(self):
|
||||||
model = TimesFmModelForPrediction.from_pretrained("google/timesfm-2.0-500m-pytorch", revision="refs/pr/7").to(
|
model = TimesFmModelForPrediction.from_pretrained("google/timesfm-2.0-500m-pytorch").to(torch_device)
|
||||||
torch_device
|
|
||||||
)
|
|
||||||
forecast_input = [
|
forecast_input = [
|
||||||
np.sin(np.linspace(0, 20, 100)),
|
np.sin(np.linspace(0, 20, 100)),
|
||||||
np.sin(np.linspace(0, 20, 200)),
|
np.sin(np.linspace(0, 20, 200)),
|
||||||
@ -184,14 +182,21 @@ class TimesFmModelIntegrationTests(unittest.TestCase):
|
|||||||
frequency_input = [0, 1, 2]
|
frequency_input = [0, 1, 2]
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
output = model(past_values=forecast_input_tensor, freq=frequency_input).last_hidden_state
|
output = model(past_values=forecast_input_tensor, freq=frequency_input)
|
||||||
|
|
||||||
self.assertEqual(
|
mean_predictions = output.mean_predictions
|
||||||
output.shape,
|
self.assertEqual(mean_predictions.shape, torch.Size([3, model.config.horizon_length]))
|
||||||
torch.Size([3, model.config.context_length // model.config.patch_length, model.config.hidden_size]),
|
# fmt: off
|
||||||
)
|
|
||||||
expected_slice = torch.tensor(
|
expected_slice = torch.tensor(
|
||||||
[[-0.4267, -0.7273, -0.3932], [-0.4267, -0.7273, -0.3932], [-0.4267, -0.7273, -0.3932]],
|
[ 0.9813, 1.0086, 0.9985, 0.9432, 0.8505, 0.7203, 0.5596, 0.3788,
|
||||||
device=torch_device,
|
0.1796, -0.0264, -0.2307, -0.4255, -0.5978, -0.7642, -0.8772, -0.9670,
|
||||||
)
|
-1.0110, -1.0162, -0.9848, -0.9151, -0.8016, -0.6511, -0.4707, -0.2842,
|
||||||
self.assertTrue(torch.allclose(output[0, :3, :3], expected_slice, atol=TOLERANCE))
|
-0.0787, 0.1260, 0.3293, 0.5104, 0.6818, 0.8155, 0.9172, 0.9843,
|
||||||
|
1.0101, 1.0025, 0.9529, 0.8588, 0.7384, 0.5885, 0.4022, 0.2099,
|
||||||
|
-0.0035, -0.2104, -0.4146, -0.6033, -0.7661, -0.8818, -0.9725, -1.0191,
|
||||||
|
-1.0190, -0.9874, -0.9137, -0.8069, -0.6683, -0.4939, -0.3086, -0.1106,
|
||||||
|
0.0846, 0.2927, 0.4832, 0.6612, 0.8031, 0.9051, 0.9772, 1.0064
|
||||||
|
],
|
||||||
|
device=torch_device)
|
||||||
|
# fmt: on
|
||||||
|
self.assertTrue(torch.allclose(mean_predictions[0, :64], expected_slice, atol=TOLERANCE))
|
||||||
|
Loading…
Reference in New Issue
Block a user