diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py
index e94174400a7..4708bde705b 100644
--- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py
+++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py
@@ -1589,14 +1589,11 @@ class TimeSeriesTransformerModel(TimeSeriesTransformerPreTrainedModel):
         sequence_length = sequence.shape[1]
         indices = [lag - shift for lag in self.config.lags_sequence]
 
-        try:
-            assert max(indices) + subsequences_length <= sequence_length, (
+        if max(indices) + subsequences_length > sequence_length:
+            raise ValueError(
                 f"lags cannot go further than history length, found lag {max(indices)} "
                 f"while history length is only {sequence_length}"
             )
-        except AssertionError as e:
-            e.args += (max(indices), sequence_length)
-            raise
 
         lagged_values = []
         for lag_index in indices:
@@ -1642,23 +1639,6 @@ class TimeSeriesTransformerModel(TimeSeriesTransformerPreTrainedModel):
             else (past_values - loc) / scale
         )
 
-        inputs_length = (
-            self._past_length + self.config.prediction_length if future_values is not None else self._past_length
-        )
-        try:
-            assert inputs.shape[1] == inputs_length, (
-                f"input length {inputs.shape[1]} and dynamic feature lengths {inputs_length} does not match",
-            )
-        except AssertionError as e:
-            e.args += (inputs.shape[1], inputs_length)
-            raise
-
-        subsequences_length = (
-            self.config.context_length + self.config.prediction_length
-            if future_values is not None
-            else self.config.context_length
-        )
-
         # static features
         log_abs_loc = loc.abs().log1p() if self.config.input_size == 1 else loc.squeeze(1).abs().log1p()
         log_scale = scale.log() if self.config.input_size == 1 else scale.squeeze(1).log()
@@ -1675,10 +1655,21 @@ class TimeSeriesTransformerModel(TimeSeriesTransformerPreTrainedModel):
         features = torch.cat((expanded_static_feat, time_feat), dim=-1)
 
         # lagged features
+        subsequences_length = (
+            self.config.context_length + self.config.prediction_length
+            if future_values is not None
+            else self.config.context_length
+        )
         lagged_sequence = self.get_lagged_subsequences(sequence=inputs, subsequences_length=subsequences_length)
         lags_shape = lagged_sequence.shape
         reshaped_lagged_sequence = lagged_sequence.reshape(lags_shape[0], lags_shape[1], -1)
 
+        if reshaped_lagged_sequence.shape[1] != time_feat.shape[1]:
+            raise ValueError(
+                f"input length {reshaped_lagged_sequence.shape[1]} and time feature lengths {time_feat.shape[1]} does not match"
+            )
+
+        # transformer inputs
         transformer_inputs = torch.cat((reshaped_lagged_sequence, features), dim=-1)
 
         return transformer_inputs, loc, scale, static_feat
diff --git a/tests/models/time_series_transformer/test_modeling_time_series_transformer.py b/tests/models/time_series_transformer/test_modeling_time_series_transformer.py
index 1ea8d8ef671..7f14f29de04 100644
--- a/tests/models/time_series_transformer/test_modeling_time_series_transformer.py
+++ b/tests/models/time_series_transformer/test_modeling_time_series_transformer.py
@@ -19,6 +19,7 @@ import tempfile
 import unittest
 
 from huggingface_hub import hf_hub_download
+from parameterized import parameterized
 
 from transformers import is_torch_available
 from transformers.testing_utils import is_flaky, require_torch, slow, torch_device
@@ -368,6 +369,90 @@ class TimeSeriesTransformerModelTest(ModelTesterMixin, PipelineTesterMixin, unit
                 [self.model_tester.num_attention_heads, encoder_seq_length, encoder_seq_length],
             )
 
+    @parameterized.expand(
+        [
+            (1, 5, [1]),
+            (1, 5, [1, 10, 15]),
+            (1, 5, [3, 6, 9, 10]),
+            (2, 5, [1, 2, 7]),
+            (2, 5, [2, 3, 4, 6]),
+            (4, 5, [1, 5, 9, 11]),
+            (4, 5, [7, 8, 13, 14]),
+        ],
+    )
+    def test_create_network_inputs(self, prediction_length, context_length, lags_sequence):
+        history_length = max(lags_sequence) + context_length
+
+        config = TimeSeriesTransformerConfig(
+            prediction_length=prediction_length,
+            context_length=context_length,
+            lags_sequence=lags_sequence,
+            scaling=False,
+            num_parallel_samples=10,
+            num_static_categorical_features=1,
+            cardinality=[1],
+            embedding_dimension=[2],
+            num_static_real_features=1,
+        )
+        model = TimeSeriesTransformerModel(config)
+
+        batch = {
+            "static_categorical_features": torch.tensor([[0]], dtype=torch.int64),
+            "static_real_features": torch.tensor([[0.0]], dtype=torch.float32),
+            "past_time_features": torch.arange(history_length, dtype=torch.float32).view(1, history_length, 1),
+            "past_values": torch.arange(history_length, dtype=torch.float32).view(1, history_length),
+            "past_observed_mask": torch.arange(history_length, dtype=torch.float32).view(1, history_length),
+        }
+
+        # test with no future_target (only one step prediction)
+        batch["future_time_features"] = torch.arange(history_length, history_length + 1, dtype=torch.float32).view(
+            1, 1, 1
+        )
+        transformer_inputs, loc, scale, _ = model.create_network_inputs(**batch)
+
+        self.assertTrue((scale == 1.0).all())
+        assert (loc == 0.0).all()
+
+        ref = torch.arange(max(lags_sequence), history_length, dtype=torch.float32)
+
+        for idx, lag in enumerate(lags_sequence):
+            assert torch.isclose(ref - lag, transformer_inputs[0, :, idx]).all()
+
+        # test with all future data
+        batch["future_time_features"] = torch.arange(
+            history_length, history_length + prediction_length, dtype=torch.float32
+        ).view(1, prediction_length, 1)
+        batch["future_values"] = torch.arange(
+            history_length, history_length + prediction_length, dtype=torch.float32
+        ).view(1, prediction_length)
+        transformer_inputs, loc, scale, _ = model.create_network_inputs(**batch)
+
+        assert (scale == 1.0).all()
+        assert (loc == 0.0).all()
+
+        ref = torch.arange(max(lags_sequence), history_length + prediction_length, dtype=torch.float32)
+
+        for idx, lag in enumerate(lags_sequence):
+            assert torch.isclose(ref - lag, transformer_inputs[0, :, idx]).all()
+
+        # test for generation
+        batch.pop("future_values")
+        transformer_inputs, loc, scale, _ = model.create_network_inputs(**batch)
+
+        lagged_sequence = model.get_lagged_subsequences(
+            sequence=batch["past_values"],
+            subsequences_length=1,
+            shift=1,
+        )
+        # assert that the last element of the lagged sequence is the one after the encoders input
+        assert transformer_inputs[0, ..., 0][-1] + 1 == lagged_sequence[0, ..., 0][-1]
+
+        future_values = torch.arange(history_length, history_length + prediction_length, dtype=torch.float32).view(
+            1, prediction_length
+        )
+        # assert that the first element of the future_values is offset by lag after the decoders input
+        assert lagged_sequence[0, ..., 0][-1] + lags_sequence[0] == future_values[0, ..., 0]
+
     @is_flaky()
     def test_retain_grad_hidden_states_attentions(self):
         super().test_retain_grad_hidden_states_attentions()