[time series] Add Time series inputs tests (#21846)

* initial test of inputs

* added test for generation

* remove asserts

* fixed test

* Update tests/models/time_series_transformer/test_modeling_time_series_transformer.py

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

---------

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
Kashif Rasul 2023-03-02 20:43:35 +01:00 committed by GitHub
parent b2a41d2be4
commit db979f7588
2 changed files with 98 additions and 22 deletions

src/transformers/models/time_series_transformer/modeling_time_series_transformer.py

@@ -1589,14 +1589,11 @@ class TimeSeriesTransformerModel(TimeSeriesTransformerPreTrainedModel):
         sequence_length = sequence.shape[1]
         indices = [lag - shift for lag in self.config.lags_sequence]
 
-        try:
-            assert max(indices) + subsequences_length <= sequence_length, (
-                f"lags cannot go further than history length, found lag {max(indices)} "
-                f"while history length is only {sequence_length}"
-            )
-        except AssertionError as e:
-            e.args += (max(indices), sequence_length)
-            raise
+        if max(indices) + subsequences_length > sequence_length:
+            raise ValueError(
+                f"lags cannot go further than history length, found lag {max(indices)} "
+                f"while history length is only {sequence_length}"
+            )
 
         lagged_values = []
         for lag_index in indices:
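For context, a minimal standalone sketch of the window arithmetic this new guard enforces; `check_lags` is a hypothetical helper, not part of the model:

```python
# Hypothetical helper mirroring the guard above; not part of the model.
def check_lags(lags_sequence, shift, subsequences_length, sequence_length):
    indices = [lag - shift for lag in lags_sequence]
    if max(indices) + subsequences_length > sequence_length:
        raise ValueError(
            f"lags cannot go further than history length, found lag {max(indices)} "
            f"while history length is only {sequence_length}"
        )

# With lags [1, 10, 15] and a 5-step window, the history must span at least
# 15 + 5 = 20 steps:
check_lags([1, 10, 15], shift=0, subsequences_length=5, sequence_length=20)  # passes
# check_lags([1, 10, 15], 0, 5, 19) would raise ValueError
```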
@@ -1642,23 +1639,6 @@ class TimeSeriesTransformerModel(TimeSeriesTransformerPreTrainedModel):
             else (past_values - loc) / scale
         )
 
-        inputs_length = (
-            self._past_length + self.config.prediction_length if future_values is not None else self._past_length
-        )
-        try:
-            assert inputs.shape[1] == inputs_length, (
-                f"input length {inputs.shape[1]} and dynamic feature lengths {inputs_length} does not match",
-            )
-        except AssertionError as e:
-            e.args += (inputs.shape[1], inputs_length)
-            raise
-
-        subsequences_length = (
-            self.config.context_length + self.config.prediction_length
-            if future_values is not None
-            else self.config.context_length
-        )
-
         # static features
         log_abs_loc = loc.abs().log1p() if self.config.input_size == 1 else loc.squeeze(1).abs().log1p()
         log_scale = scale.log() if self.config.input_size == 1 else scale.squeeze(1).log()
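For reference, the guard deleted here compared `inputs.shape[1]` to a precomputed `inputs_length`; a sketch of that arithmetic, assuming `_past_length` is `context_length + max(lags_sequence)` as elsewhere in this model (numbers illustrative):

```python
# Illustrative numbers only; mirrors the arithmetic of the deleted guard.
context_length, prediction_length, lags_sequence = 5, 2, [1, 10, 15]
past_length = context_length + max(lags_sequence)  # plays the role of self._past_length

# with future_values, the concatenated inputs cover past and prediction steps
inputs_length_with_future = past_length + prediction_length  # 22
# without future_values, only the past window remains
inputs_length_without_future = past_length  # 20
```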
@@ -1675,10 +1655,21 @@ class TimeSeriesTransformerModel(TimeSeriesTransformerPreTrainedModel):
             features = torch.cat((expanded_static_feat, time_feat), dim=-1)
 
         # lagged features
+        subsequences_length = (
+            self.config.context_length + self.config.prediction_length
+            if future_values is not None
+            else self.config.context_length
+        )
         lagged_sequence = self.get_lagged_subsequences(sequence=inputs, subsequences_length=subsequences_length)
         lags_shape = lagged_sequence.shape
         reshaped_lagged_sequence = lagged_sequence.reshape(lags_shape[0], lags_shape[1], -1)
 
+        if reshaped_lagged_sequence.shape[1] != time_feat.shape[1]:
+            raise ValueError(
+                f"input length {reshaped_lagged_sequence.shape[1]} and time feature lengths {time_feat.shape[1]} does not match"
+            )
+
         # transformer inputs
         transformer_inputs = torch.cat((reshaped_lagged_sequence, features), dim=-1)
 
         return transformer_inputs, loc, scale, static_feat
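A toy sketch of the alignment the new check enforces before the final concatenation: the reshaped lagged values and the time features must agree along the sequence dimension. Shapes below are illustrative, not taken from a real config:

```python
import torch

batch, subsequences_length, num_lags, num_time_feat = 1, 5, 3, 2
reshaped_lagged_sequence = torch.randn(batch, subsequences_length, num_lags)
features = torch.randn(batch, subsequences_length, num_time_feat)

# the guard above rejects mismatched sequence lengths before torch.cat
if reshaped_lagged_sequence.shape[1] != features.shape[1]:
    raise ValueError("input length and time feature lengths do not match")

transformer_inputs = torch.cat((reshaped_lagged_sequence, features), dim=-1)
print(transformer_inputs.shape)  # torch.Size([1, 5, 5])
```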

tests/models/time_series_transformer/test_modeling_time_series_transformer.py

@@ -19,6 +19,7 @@ import tempfile
 import unittest
 
 from huggingface_hub import hf_hub_download
+from parameterized import parameterized
 
 from transformers import is_torch_available
 from transformers.testing_utils import is_flaky, require_torch, slow, torch_device
@@ -368,6 +369,90 @@ class TimeSeriesTransformerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
                 [self.model_tester.num_attention_heads, encoder_seq_length, encoder_seq_length],
             )
 
+    @parameterized.expand(
+        [
+            (1, 5, [1]),
+            (1, 5, [1, 10, 15]),
+            (1, 5, [3, 6, 9, 10]),
+            (2, 5, [1, 2, 7]),
+            (2, 5, [2, 3, 4, 6]),
+            (4, 5, [1, 5, 9, 11]),
+            (4, 5, [7, 8, 13, 14]),
+        ],
+    )
+    def test_create_network_inputs(self, prediction_length, context_length, lags_sequence):
+        history_length = max(lags_sequence) + context_length
+
+        config = TimeSeriesTransformerConfig(
+            prediction_length=prediction_length,
+            context_length=context_length,
+            lags_sequence=lags_sequence,
+            scaling=False,
+            num_parallel_samples=10,
+            num_static_categorical_features=1,
+            cardinality=[1],
+            embedding_dimension=[2],
+            num_static_real_features=1,
+        )
+        model = TimeSeriesTransformerModel(config)
+        batch = {
+            "static_categorical_features": torch.tensor([[0]], dtype=torch.int64),
+            "static_real_features": torch.tensor([[0.0]], dtype=torch.float32),
+            "past_time_features": torch.arange(history_length, dtype=torch.float32).view(1, history_length, 1),
+            "past_values": torch.arange(history_length, dtype=torch.float32).view(1, history_length),
+            "past_observed_mask": torch.arange(history_length, dtype=torch.float32).view(1, history_length),
+        }
+
+        # test with no future_target (only one step prediction)
+        batch["future_time_features"] = torch.arange(history_length, history_length + 1, dtype=torch.float32).view(
+            1, 1, 1
+        )
+        transformer_inputs, loc, scale, _ = model.create_network_inputs(**batch)
+
+        self.assertTrue((scale == 1.0).all())
+        assert (loc == 0.0).all()
+
+        ref = torch.arange(max(lags_sequence), history_length, dtype=torch.float32)
+
+        for idx, lag in enumerate(lags_sequence):
+            assert torch.isclose(ref - lag, transformer_inputs[0, :, idx]).all()
+
+        # test with all future data
+        batch["future_time_features"] = torch.arange(
+            history_length, history_length + prediction_length, dtype=torch.float32
+        ).view(1, prediction_length, 1)
+        batch["future_values"] = torch.arange(
+            history_length, history_length + prediction_length, dtype=torch.float32
+        ).view(1, prediction_length)
+        transformer_inputs, loc, scale, _ = model.create_network_inputs(**batch)
+
+        assert (scale == 1.0).all()
+        assert (loc == 0.0).all()
+
+        ref = torch.arange(max(lags_sequence), history_length + prediction_length, dtype=torch.float32)
+
+        for idx, lag in enumerate(lags_sequence):
+            assert torch.isclose(ref - lag, transformer_inputs[0, :, idx]).all()
+
+        # test for generation
+        batch.pop("future_values")
+        transformer_inputs, loc, scale, _ = model.create_network_inputs(**batch)
+
+        lagged_sequence = model.get_lagged_subsequences(
+            sequence=batch["past_values"],
+            subsequences_length=1,
+            shift=1,
+        )
+        # assert that the last element of the lagged sequence is the one after the encoder's input
+        assert transformer_inputs[0, ..., 0][-1] + 1 == lagged_sequence[0, ..., 0][-1]
+
+        future_values = torch.arange(history_length, history_length + prediction_length, dtype=torch.float32).view(
+            1, prediction_length
+        )
+        # assert that the first element of the future_values is offset by lag after the decoder's input
+        assert lagged_sequence[0, ..., 0][-1] + lags_sequence[0] == future_values[0, ..., 0]
+
     @is_flaky()
     def test_retain_grad_hidden_states_attentions(self):
         super().test_retain_grad_hidden_states_attentions()
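A condensed sketch of the invariant the new test checks, without instantiating the model: with scaling disabled and `past_values = arange(history_length)`, lag channel `idx` of the encoder input is the series shifted back by `lags_sequence[idx]` steps. The slicing below is a restatement of the test's arithmetic, not code from the suite:

```python
import torch

lags_sequence, context_length = [1, 10, 15], 5
history_length = max(lags_sequence) + context_length
past_values = torch.arange(history_length, dtype=torch.float32)

ref = torch.arange(max(lags_sequence), history_length, dtype=torch.float32)
for idx, lag in enumerate(lags_sequence):
    # each lag channel is the identity series shifted back by `lag` steps
    window = past_values[max(lags_sequence) - lag : history_length - lag]
    assert torch.equal(window, ref - lag)
print("lag-channel invariant holds")
```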