transformers/tests/models/informer/test_modeling_informer.py
Eli Simhayev 8abe4930d3
[Time-Series] informer model (#21099)
* added informer to gitignore

* added informer to gitignore

* WIP informer2020

* added checking that instantiate works

* added config using gluonTS by kashif

* WIP config

* adding informeConfig. need to remove FeatureEmbedder

* done InformerConfig, but need to change the names

* Done informer model init. working on enc-dec

* added things to address, after reading again enc-dec in the paper

* done modeling - checking initialization work

* added informer to gitignore

* WIP informer2020

* added checking that instantiate works

* added config using gluonTS by kashif

* WIP config

* adding informeConfig. need to remove FeatureEmbedder

* done InformerConfig, but need to change the names

* Done informer model init. working on enc-dec

* added things to address, after reading again enc-dec in the paper

* done modeling - checking initialization work

* moved enc-dec init to InformerEncoder/Decoder init

* added 'init_std' to config, now model init works!

* WIP conversion script, and added code sources

* WIP conversion script: loading original informer pth works

* WIP conversion script: change defaults in the config

* WIP conversion script: supporting Informer input embedding

* WIP conversion script: added parameters for the informer embed

* WIP conversion script: change dim_feedforward=2048

* WIP conversion script: remove unused args for loading checkpoint

* just cleaning up

* DataEmbedding removed, after thinking with Kashif

* working on forward pass

* WIP forward pass: trying to establish working batch for forward pass

* cleaning and finalizing

* adding HF names and docs

* init after cleaning works

* WIP in tests

* added docs for the informer specific args

* fix style

* undo change

* cleaning informer, now need to work only enc-dec

* initial enc-dec classes

* added encoder and decoder

* added todo

* add todos for conv_layers

* added decoder docs from vanilla

* added encoder docs from vanilla

* remove encoder decoder from the original informer

* removed AttentionLayer from the original paper

* removed TriangularCausalMask, same as decoder_attention_mask

* initial sparse attention

* use conv_layers

* fixed test_config test

* fix parenthesis when itearting zip(layers, conv_layers)

* error found in prob attention, added sizes as comments

* fix sizes

* added proposal for q_reduce indexing, and remove unused

* WIP ProbMask, and changed factor=2 for testing

* remove unused libs for this PR for creating the env

* fix checking the attn_weights.size() after bmm

* Q_reduce: changed from torch.gather to simple slicing

* WIP calculate final attn_output

* finish adding v_aggregated, attn_output ready

* changed tgt_len to u in attention_mask, need to fix the size error

* comment attention_mask for encoder, and fix if cond for v_agg

* added ProbMask support (wip), removed old original code

* finished ProbMask 😃

* Revert "remove unused libs for this PR for creating the env"

This reverts commit 11a081e09e.

* fixes

* make style

* fix initial tests

* fix more tests

* dry

* make style

* remove unused files

* style

* added integration tests

* fix num_static_real_features

* fix header

* remove unused function

* fix example

* fix docs

* Update src/transformers/models/informer/configuration_informer.py

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/informer/modeling_informer.py

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/informer/configuration_informer.py

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/informer/configuration_informer.py

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/informer/configuration_informer.py

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/informer/configuration_informer.py

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* fixes for reviewer

* use prediction_length from model

* fix style

* fixed informer.mdx

* added to index

* updated readme

* undo

* make fix-copies

* typo

* fix copy

* added Informer to toctree

* in order

* fixed comments

* remove unneeded new lines in docs

* make static real and cat optional

* fix use of distil conv layers

* fixed integration test

* added checkpoint for convlayer

* make fix-copies

* updated from time series model

* make fix-copies

* copy decoder

* fix unit tests

* updated scaling config

* fix integration tests

* IGNORE_NON_TESTED

* IGNORE_NON_AUTO_CONFIGURED

* IGNORE_NON_AUTO_CONFIGURED

* updated check configs

* fix formatting

* undo change from time series

* prediction_length should not be None

* aliign with the blog: prettify ProbSparse and change attention_factor  to sampling_factor

* make style

* make fix-copies

* niels CR: update contributed by

* niels CR: update configuration_informer.py

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* niels CR: update kashif -> huggingface

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* niels CR: `sampling_factor` only relevant when `attention_type`=prob

* make style

* fixed U_part: added multiplication by `L_Q`

* fixed bug: remove `is not None` from `if config.distil`

* fixed test: `decoder_seq_length` to `encoder_seq_length` in cross_attentions check

* fix integration tests

* updated model hub

* do not shift as in training

* undo

* fix make-copies

* make fix-copies

* added `if prediction_length is None`

* changed `ProbSparseAttention` to `InformerProbSparseAttention`

* changed `V_sum` -> `v_mean_dim_time`

* changed `ConvLayer` to `InformerConvLayer` and fixed `super()`

* TimeSeriesTansformer->Informer in decoder's Copied from

* more descriptive in ProbSparse

* make style

* fix coped from

* Revert "added `if prediction_length is None`"

This reverts commit b4cbddfa05.

* fixed indent

* use InformerSinusoidalPositionalEmbedding

* make fix-style

* fix from #21860

* fix name

* make fix-copies

* use time series utils

* fix dec num_heads

* docstring

* added time series util doc

* _import_structure

* formatting

* changes from review

* make style

* fix docs

* fix doc

* removed NegativeLogLikelihood

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
2023-03-07 21:36:38 +01:00

512 lines
21 KiB
Python

# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch Informer model. """
import inspect
import tempfile
import unittest
import numpy as np
from huggingface_hub import hf_hub_download
from transformers import is_torch_available
from transformers.testing_utils import is_flaky, require_torch, slow, torch_device
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
TOLERANCE = 1e-4
if is_torch_available():
import torch
from transformers import InformerConfig, InformerForPrediction, InformerModel
from transformers.models.informer.modeling_informer import InformerDecoder, InformerEncoder
@require_torch
class InformerModelTester:
def __init__(
self,
parent,
batch_size=13,
prediction_length=7,
context_length=14,
cardinality=19,
embedding_dimension=5,
num_time_features=4,
is_training=True,
hidden_size=16,
num_hidden_layers=2,
num_attention_heads=4,
intermediate_size=4,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
lags_sequence=[1, 2, 3, 4, 5],
sampling_factor=10,
distil=False,
):
self.parent = parent
self.batch_size = batch_size
self.prediction_length = prediction_length
self.context_length = context_length
self.cardinality = cardinality
self.num_time_features = num_time_features
self.lags_sequence = lags_sequence
self.embedding_dimension = embedding_dimension
self.is_training = is_training
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.encoder_seq_length = min(
sampling_factor * np.ceil(np.log1p(context_length)).astype("int").item(), context_length
)
self.decoder_seq_length = min(
sampling_factor * np.ceil(np.log1p(prediction_length)).astype("int").item(), prediction_length
)
self.sampling_factor = sampling_factor
self.distil = distil
def get_config(self):
return InformerConfig(
prediction_length=self.prediction_length,
d_model=self.hidden_size,
encoder_layers=self.num_hidden_layers,
decoder_layers=self.num_hidden_layers,
encoder_attention_heads=self.num_attention_heads,
decoder_attention_heads=self.num_attention_heads,
encoder_ffn_dim=self.intermediate_size,
decoder_ffn_dim=self.intermediate_size,
dropout=self.hidden_dropout_prob,
attention_dropout=self.attention_probs_dropout_prob,
context_length=self.context_length,
lags_sequence=self.lags_sequence,
num_time_features=self.num_time_features,
num_static_categorical_features=1,
num_static_real_features=1,
cardinality=[self.cardinality],
embedding_dimension=[self.embedding_dimension],
sampling_factor=self.sampling_factor,
distil=self.distil,
)
def prepare_informer_inputs_dict(self, config):
_past_length = config.context_length + max(config.lags_sequence)
static_categorical_features = ids_tensor([self.batch_size, 1], config.cardinality[0])
static_real_features = floats_tensor([self.batch_size, 1])
past_time_features = floats_tensor([self.batch_size, _past_length, config.num_time_features])
past_values = floats_tensor([self.batch_size, _past_length])
past_observed_mask = floats_tensor([self.batch_size, _past_length])
# decoder inputs
future_time_features = floats_tensor([self.batch_size, config.prediction_length, config.num_time_features])
future_values = floats_tensor([self.batch_size, config.prediction_length])
inputs_dict = {
"past_values": past_values,
"static_categorical_features": static_categorical_features,
"static_real_features": static_real_features,
"past_time_features": past_time_features,
"past_observed_mask": past_observed_mask,
"future_time_features": future_time_features,
"future_values": future_values,
}
return inputs_dict
def prepare_config_and_inputs(self):
config = self.get_config()
inputs_dict = self.prepare_informer_inputs_dict(config)
return config, inputs_dict
def prepare_config_and_inputs_for_common(self):
config, inputs_dict = self.prepare_config_and_inputs()
return config, inputs_dict
def check_encoder_decoder_model_standalone(self, config, inputs_dict):
model = InformerModel(config=config).to(torch_device).eval()
outputs = model(**inputs_dict)
encoder_last_hidden_state = outputs.encoder_last_hidden_state
last_hidden_state = outputs.last_hidden_state
with tempfile.TemporaryDirectory() as tmpdirname:
encoder = model.get_encoder()
encoder.save_pretrained(tmpdirname)
encoder = InformerEncoder.from_pretrained(tmpdirname).to(torch_device)
transformer_inputs, _, _, _ = model.create_network_inputs(**inputs_dict)
enc_input = transformer_inputs[:, : config.context_length, ...]
dec_input = transformer_inputs[:, config.context_length :, ...]
encoder_last_hidden_state_2 = encoder(inputs_embeds=enc_input)[0]
self.parent.assertTrue((encoder_last_hidden_state_2 - encoder_last_hidden_state).abs().max().item() < 1e-3)
with tempfile.TemporaryDirectory() as tmpdirname:
decoder = model.get_decoder()
decoder.save_pretrained(tmpdirname)
decoder = InformerDecoder.from_pretrained(tmpdirname).to(torch_device)
last_hidden_state_2 = decoder(
inputs_embeds=dec_input,
encoder_hidden_states=encoder_last_hidden_state,
)[0]
self.parent.assertTrue((last_hidden_state_2 - last_hidden_state).abs().max().item() < 1e-3)
@require_torch
class InformerModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (InformerModel, InformerForPrediction) if is_torch_available() else ()
all_generative_model_classes = (InformerForPrediction,) if is_torch_available() else ()
is_encoder_decoder = True
test_pruning = False
test_head_masking = False
test_missing_keys = False
test_torchscript = False
test_inputs_embeds = False
test_model_common_attributes = False
def setUp(self):
self.model_tester = InformerModelTester(self)
self.config_tester = ConfigTester(
self,
config_class=InformerConfig,
has_text_modality=False,
prediction_length=self.model_tester.prediction_length,
)
def test_config(self):
self.config_tester.run_common_tests()
def test_save_load_strict(self):
config, _ = self.model_tester.prepare_config_and_inputs()
for model_class in self.all_model_classes:
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
self.assertEqual(info["missing_keys"], [])
def test_encoder_decoder_model_standalone(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
self.model_tester.check_encoder_decoder_model_standalone(*config_and_inputs)
def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class):
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
expected_num_layers = getattr(
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
)
self.assertEqual(len(hidden_states), expected_num_layers)
if hasattr(self.model_tester, "encoder_seq_length"):
seq_length = self.model_tester.context_length
if hasattr(self.model_tester, "chunk_length") and self.model_tester.chunk_length > 1:
seq_length = seq_length * self.model_tester.chunk_length
else:
seq_length = self.model_tester.seq_length
self.assertListEqual(
list(hidden_states[0].shape[-2:]),
[seq_length, self.model_tester.hidden_size],
)
if config.is_encoder_decoder:
hidden_states = outputs.decoder_hidden_states
self.assertIsInstance(hidden_states, (list, tuple))
self.assertEqual(len(hidden_states), expected_num_layers)
seq_len = getattr(self.model_tester, "seq_length", None)
decoder_seq_length = getattr(self.model_tester, "prediction_length", seq_len)
self.assertListEqual(
list(hidden_states[0].shape[-2:]),
[decoder_seq_length, self.model_tester.hidden_size],
)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
inputs_dict["output_hidden_states"] = True
check_hidden_states_output(inputs_dict, config, model_class)
# check that output_hidden_states also work using config
del inputs_dict["output_hidden_states"]
config.output_hidden_states = True
check_hidden_states_output(inputs_dict, config, model_class)
# Ignore since we have no tokens embeddings
def test_resize_tokens_embeddings(self):
pass
def test_model_outputs_equivalence(self):
pass
def test_determinism(self):
pass
# # Input is 'static_categorical_features' not 'input_ids'
def test_model_main_input_name(self):
model_signature = inspect.signature(getattr(InformerModel, "forward"))
# The main input is the name of the argument after `self`
observed_main_input_name = list(model_signature.parameters.keys())[1]
self.assertEqual(InformerModel.main_input_name, observed_main_input_name)
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.forward)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = [
"past_values",
"past_time_features",
"past_observed_mask",
"static_categorical_features",
"static_real_features",
"future_values",
"future_time_features",
]
expected_arg_names.extend(
[
"future_observed_mask",
"decoder_attention_mask",
"head_mask",
"decoder_head_mask",
"cross_attn_head_mask",
"encoder_outputs",
"past_key_values",
"output_hidden_states",
"output_attentions",
"use_cache",
"return_dict",
]
if "future_observed_mask" in arg_names
else [
"decoder_attention_mask",
"head_mask",
"decoder_head_mask",
"cross_attn_head_mask",
"encoder_outputs",
"past_key_values",
"output_hidden_states",
"output_attentions",
"use_cache",
"return_dict",
]
)
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
seq_len = getattr(self.model_tester, "seq_length", None)
decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len)
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
context_length = getattr(self.model_tester, "context_length", seq_len)
prediction_length = getattr(self.model_tester, "prediction_length", seq_len)
for model_class in self.all_model_classes:
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = False
config.return_dict = True
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
# check that output_attentions also work using config
del inputs_dict["output_attentions"]
config.output_attentions = True
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.encoder_attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, context_length],
)
out_len = len(outputs)
correct_outlen = 7
if "last_hidden_state" in outputs:
correct_outlen += 1
if "past_key_values" in outputs:
correct_outlen += 1 # past_key_values have been returned
if "loss" in outputs:
correct_outlen += 1
if "params" in outputs:
correct_outlen += 1
self.assertEqual(out_len, correct_outlen)
# decoder attentions
decoder_attentions = outputs.decoder_attentions
self.assertIsInstance(decoder_attentions, (list, tuple))
self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(decoder_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, decoder_seq_length, prediction_length],
)
# cross attentions
cross_attentions = outputs.cross_attentions
self.assertIsInstance(cross_attentions, (list, tuple))
self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(cross_attentions[0].shape[-3:]),
[
self.model_tester.num_attention_heads,
decoder_seq_length,
encoder_seq_length,
],
)
# Check attention is always last and order is fine
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = True
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
self.assertEqual(out_len + 2, len(outputs))
self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(self_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, context_length],
)
@is_flaky()
def test_retain_grad_hidden_states_attentions(self):
super().test_retain_grad_hidden_states_attentions()
def prepare_batch(filename="train-batch.pt"):
file = hf_hub_download(repo_id="kashif/tourism-monthly-batch", filename=filename, repo_type="dataset")
batch = torch.load(file, map_location=torch_device)
return batch
@require_torch
@slow
class InformerModelIntegrationTests(unittest.TestCase):
def test_inference_no_head(self):
model = InformerModel.from_pretrained("huggingface/informer-tourism-monthly").to(torch_device)
batch = prepare_batch()
torch.manual_seed(0)
with torch.no_grad():
output = model(
past_values=batch["past_values"],
past_time_features=batch["past_time_features"],
past_observed_mask=batch["past_observed_mask"],
static_categorical_features=batch["static_categorical_features"],
future_values=batch["future_values"],
future_time_features=batch["future_time_features"],
).last_hidden_state
expected_shape = torch.Size((64, model.config.context_length, model.config.d_model))
self.assertEqual(output.shape, expected_shape)
expected_slice = torch.tensor(
[[0.4699, 0.7295, 0.8967], [0.4858, 0.3810, 0.9641], [-0.0233, 0.3608, 1.0303]],
device=torch_device,
)
self.assertTrue(torch.allclose(output[0, :3, :3], expected_slice, atol=TOLERANCE))
def test_inference_head(self):
model = InformerForPrediction.from_pretrained("huggingface/informer-tourism-monthly").to(torch_device)
batch = prepare_batch("val-batch.pt")
torch.manual_seed(0)
with torch.no_grad():
output = model(
past_values=batch["past_values"],
past_time_features=batch["past_time_features"],
past_observed_mask=batch["past_observed_mask"],
static_categorical_features=batch["static_categorical_features"],
future_time_features=batch["future_time_features"],
).encoder_last_hidden_state
# encoder distils the context length to 1/8th of the original length
expected_shape = torch.Size((64, model.config.context_length // 8, model.config.d_model))
self.assertEqual(output.shape, expected_shape)
expected_slice = torch.tensor(
[[0.4170, 0.9067, 0.8153], [0.3004, 0.7574, 0.7066], [0.6803, -0.6323, 1.2802]], device=torch_device
)
self.assertTrue(torch.allclose(output[0, :3, :3], expected_slice, atol=TOLERANCE))
def test_seq_to_seq_generation(self):
model = InformerForPrediction.from_pretrained("huggingface/informer-tourism-monthly").to(torch_device)
batch = prepare_batch("val-batch.pt")
torch.manual_seed(0)
with torch.no_grad():
outputs = model.generate(
static_categorical_features=batch["static_categorical_features"],
past_time_features=batch["past_time_features"],
past_values=batch["past_values"],
future_time_features=batch["future_time_features"],
past_observed_mask=batch["past_observed_mask"],
)
expected_shape = torch.Size((64, model.config.num_parallel_samples, model.config.prediction_length))
self.assertEqual(outputs.sequences.shape, expected_shape)
expected_slice = torch.tensor([3400.8005, 4289.2637, 7101.9209], device=torch_device)
mean_prediction = outputs.sequences.mean(dim=1)
self.assertTrue(torch.allclose(mean_prediction[0, -3:], expected_slice, rtol=1e-1))