fix style

This commit is contained in:
Kashif Rasul 2023-01-30 12:04:13 +01:00
parent 83d39df0b1
commit fdffeb819c
8 changed files with 122 additions and 161 deletions

View File

@ -290,6 +290,10 @@ _import_structure = {
"models.hubert": ["HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "HubertConfig"], "models.hubert": ["HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "HubertConfig"],
"models.ibert": ["IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "IBertConfig"], "models.ibert": ["IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "IBertConfig"],
"models.imagegpt": ["IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ImageGPTConfig"], "models.imagegpt": ["IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ImageGPTConfig"],
"models.informer": [
"INFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
"InformerConfig",
],
"models.jukebox": [ "models.jukebox": [
"JUKEBOX_PRETRAINED_CONFIG_ARCHIVE_MAP", "JUKEBOX_PRETRAINED_CONFIG_ARCHIVE_MAP",
"JukeboxConfig", "JukeboxConfig",
@ -414,10 +418,6 @@ _import_structure = {
"TIME_SERIES_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "TIME_SERIES_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
"TimeSeriesTransformerConfig", "TimeSeriesTransformerConfig",
], ],
"models.informer": [
"INFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
"InformerConfig",
],
"models.timesformer": ["TIMESFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "TimesformerConfig"], "models.timesformer": ["TIMESFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "TimesformerConfig"],
"models.trajectory_transformer": [ "models.trajectory_transformer": [
"TRAJECTORY_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "TRAJECTORY_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
@ -1621,6 +1621,14 @@ else:
"load_tf_weights_in_imagegpt", "load_tf_weights_in_imagegpt",
] ]
) )
_import_structure["models.informer"].extend(
[
"INFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"InformerForPrediction",
"InformerModel",
"InformerPreTrainedModel",
]
)
_import_structure["models.jukebox"].extend( _import_structure["models.jukebox"].extend(
[ [
"JUKEBOX_PRETRAINED_MODEL_ARCHIVE_LIST", "JUKEBOX_PRETRAINED_MODEL_ARCHIVE_LIST",
@ -2275,14 +2283,6 @@ else:
"TimeSeriesTransformerPreTrainedModel", "TimeSeriesTransformerPreTrainedModel",
] ]
) )
_import_structure["models.informer"].extend(
[
"INFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"InformerForPrediction",
"InformerModel",
"InformerPreTrainedModel",
]
)
_import_structure["models.timesformer"].extend( _import_structure["models.timesformer"].extend(
[ [
"TIMESFORMER_PRETRAINED_MODEL_ARCHIVE_LIST", "TIMESFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
@ -3741,6 +3741,7 @@ if TYPE_CHECKING:
from .models.hubert import HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, HubertConfig from .models.hubert import HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, HubertConfig
from .models.ibert import IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, IBertConfig from .models.ibert import IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, IBertConfig
from .models.imagegpt import IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP, ImageGPTConfig from .models.imagegpt import IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP, ImageGPTConfig
from .models.informer import INFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, InformerConfig
from .models.jukebox import ( from .models.jukebox import (
JUKEBOX_PRETRAINED_CONFIG_ARCHIVE_MAP, JUKEBOX_PRETRAINED_CONFIG_ARCHIVE_MAP,
JukeboxConfig, JukeboxConfig,
@ -3855,10 +3856,6 @@ if TYPE_CHECKING:
TIME_SERIES_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, TIME_SERIES_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
TimeSeriesTransformerConfig, TimeSeriesTransformerConfig,
) )
from .models.informer import (
INFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
InformerConfig,
)
from .models.timesformer import TIMESFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, TimesformerConfig from .models.timesformer import TIMESFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, TimesformerConfig
from .models.trajectory_transformer import ( from .models.trajectory_transformer import (
TRAJECTORY_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, TRAJECTORY_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
@ -4865,6 +4862,12 @@ if TYPE_CHECKING:
ImageGPTPreTrainedModel, ImageGPTPreTrainedModel,
load_tf_weights_in_imagegpt, load_tf_weights_in_imagegpt,
) )
from .models.informer import (
INFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
InformerForPrediction,
InformerModel,
InformerPreTrainedModel,
)
from .models.jukebox import ( from .models.jukebox import (
JUKEBOX_PRETRAINED_MODEL_ARCHIVE_LIST, JUKEBOX_PRETRAINED_MODEL_ARCHIVE_LIST,
JukeboxModel, JukeboxModel,
@ -5393,12 +5396,6 @@ if TYPE_CHECKING:
TimeSeriesTransformerModel, TimeSeriesTransformerModel,
TimeSeriesTransformerPreTrainedModel, TimeSeriesTransformerPreTrainedModel,
) )
from .models.informer import (
INFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
InformerForPrediction,
InformerModel,
InformerPreTrainedModel,
)
from .models.timesformer import ( from .models.timesformer import (
TIMESFORMER_PRETRAINED_MODEL_ARCHIVE_LIST, TIMESFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
TimesformerForVideoClassification, TimesformerForVideoClassification,

View File

@ -90,6 +90,7 @@ from . import (
hubert, hubert,
ibert, ibert,
imagegpt, imagegpt,
informer,
jukebox, jukebox,
layoutlm, layoutlm,
layoutlmv2, layoutlmv2,
@ -165,7 +166,6 @@ from . import (
tapas, tapas,
tapex, tapex,
time_series_transformer, time_series_transformer,
informer,
timesformer, timesformer,
trajectory_transformer, trajectory_transformer,
transfo_xl, transfo_xl,

View File

@ -93,6 +93,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("hubert", "HubertConfig"), ("hubert", "HubertConfig"),
("ibert", "IBertConfig"), ("ibert", "IBertConfig"),
("imagegpt", "ImageGPTConfig"), ("imagegpt", "ImageGPTConfig"),
("informer", "InformerConfig"),
("jukebox", "JukeboxConfig"), ("jukebox", "JukeboxConfig"),
("layoutlm", "LayoutLMConfig"), ("layoutlm", "LayoutLMConfig"),
("layoutlmv2", "LayoutLMv2Config"), ("layoutlmv2", "LayoutLMv2Config"),
@ -161,7 +162,6 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("table-transformer", "TableTransformerConfig"), ("table-transformer", "TableTransformerConfig"),
("tapas", "TapasConfig"), ("tapas", "TapasConfig"),
("time_series_transformer", "TimeSeriesTransformerConfig"), ("time_series_transformer", "TimeSeriesTransformerConfig"),
("informer", "InformerConfig"),
("timesformer", "TimesformerConfig"), ("timesformer", "TimesformerConfig"),
("trajectory_transformer", "TrajectoryTransformerConfig"), ("trajectory_transformer", "TrajectoryTransformerConfig"),
("transfo-xl", "TransfoXLConfig"), ("transfo-xl", "TransfoXLConfig"),
@ -258,6 +258,7 @@ CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
("hubert", "HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("hubert", "HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("ibert", "IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("ibert", "IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("imagegpt", "IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("imagegpt", "IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("informer", "INFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("jukebox", "JUKEBOX_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("jukebox", "JUKEBOX_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("layoutlm", "LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("layoutlm", "LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("layoutlmv2", "LAYOUTLMV2_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("layoutlmv2", "LAYOUTLMV2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
@ -319,7 +320,6 @@ CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
("table-transformer", "TABLE_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("table-transformer", "TABLE_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("tapas", "TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("tapas", "TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("time_series_transformer", "TIME_SERIES_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("time_series_transformer", "TIME_SERIES_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("informer", "INFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("timesformer", "TIMESFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("timesformer", "TIMESFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("transfo-xl", "TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("transfo-xl", "TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("unispeech", "UNISPEECH_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("unispeech", "UNISPEECH_PRETRAINED_CONFIG_ARCHIVE_MAP"),
@ -424,6 +424,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("hubert", "Hubert"), ("hubert", "Hubert"),
("ibert", "I-BERT"), ("ibert", "I-BERT"),
("imagegpt", "ImageGPT"), ("imagegpt", "ImageGPT"),
("informer", "Informer"),
("jukebox", "Jukebox"), ("jukebox", "Jukebox"),
("layoutlm", "LayoutLM"), ("layoutlm", "LayoutLM"),
("layoutlmv2", "LayoutLMv2"), ("layoutlmv2", "LayoutLMv2"),
@ -500,7 +501,6 @@ MODEL_NAMES_MAPPING = OrderedDict(
("tapas", "TAPAS"), ("tapas", "TAPAS"),
("tapex", "TAPEX"), ("tapex", "TAPEX"),
("time_series_transformer", "Time Series Transformer"), ("time_series_transformer", "Time Series Transformer"),
("informer", "Informer"),
("timesformer", "TimeSformer"), ("timesformer", "TimeSformer"),
("trajectory_transformer", "Trajectory Transformer"), ("trajectory_transformer", "Trajectory Transformer"),
("transfo-xl", "Transformer-XL"), ("transfo-xl", "Transformer-XL"),

View File

@ -92,6 +92,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("hubert", "HubertModel"), ("hubert", "HubertModel"),
("ibert", "IBertModel"), ("ibert", "IBertModel"),
("imagegpt", "ImageGPTModel"), ("imagegpt", "ImageGPTModel"),
("informer", "InformerModel"),
("jukebox", "JukeboxModel"), ("jukebox", "JukeboxModel"),
("layoutlm", "LayoutLMModel"), ("layoutlm", "LayoutLMModel"),
("layoutlmv2", "LayoutLMv2Model"), ("layoutlmv2", "LayoutLMv2Model"),
@ -157,7 +158,6 @@ MODEL_MAPPING_NAMES = OrderedDict(
("table-transformer", "TableTransformerModel"), ("table-transformer", "TableTransformerModel"),
("tapas", "TapasModel"), ("tapas", "TapasModel"),
("time_series_transformer", "TimeSeriesTransformerModel"), ("time_series_transformer", "TimeSeriesTransformerModel"),
("informer", "InformerModel"),
("timesformer", "TimesformerModel"), ("timesformer", "TimesformerModel"),
("trajectory_transformer", "TrajectoryTransformerModel"), ("trajectory_transformer", "TrajectoryTransformerModel"),
("transfo-xl", "TransfoXLModel"), ("transfo-xl", "TransfoXLModel"),

View File

@ -43,10 +43,7 @@ else:
if TYPE_CHECKING: if TYPE_CHECKING:
from .configuration_informer import ( from .configuration_informer import INFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, InformerConfig
INFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
InformerConfig,
)
try: try:
if not is_torch_available(): if not is_torch_available():

View File

@ -27,13 +27,11 @@ INFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
} }
class InformerConfig(PretrainedConfig): class InformerConfig(PretrainedConfig):
r""" r"""
This is the configuration class to store the configuration of a [`InformerModel`]. It is used to This is the configuration class to store the configuration of a [`InformerModel`]. It is used to instantiate an
instantiate an Informer model according to the specified arguments, defining the model architecture. Informer model according to the specified arguments, defining the model architecture. Instantiating a configuration
Instantiating a configuration with the defaults will yield a similar configuration to that of the Time Series with the defaults will yield a similar configuration to that of the Time Series Transformer
Transformer
[huggingface/time-series-transformer-tourism-monthly](https://huggingface.co/huggingface/time-series-transformer-tourism-monthly) [huggingface/time-series-transformer-tourism-monthly](https://huggingface.co/huggingface/time-series-transformer-tourism-monthly)
architecture. architecture.
@ -136,47 +134,47 @@ class InformerConfig(PretrainedConfig):
} }
def __init__( def __init__(
self, self,
input_size: int = 1, input_size: int = 1,
prediction_length: Optional[int] = None, prediction_length: Optional[int] = None,
context_length: Optional[int] = None, context_length: Optional[int] = None,
distribution_output: str = "student_t", distribution_output: str = "student_t",
loss: str = "nll", loss: str = "nll",
lags_sequence: List[int] = None, lags_sequence: List[int] = None,
scaling: bool = True, scaling: bool = True,
num_dynamic_real_features: int = 0, num_dynamic_real_features: int = 0,
num_static_real_features: int = 0, num_static_real_features: int = 0,
num_static_categorical_features: int = 0, num_static_categorical_features: int = 0,
num_time_features: int = 0, num_time_features: int = 0,
cardinality: Optional[List[int]] = None, cardinality: Optional[List[int]] = None,
embedding_dimension: Optional[List[int]] = None, embedding_dimension: Optional[List[int]] = None,
encoder_ffn_dim: int = 32, encoder_ffn_dim: int = 32,
decoder_ffn_dim: int = 32, decoder_ffn_dim: int = 32,
encoder_attention_heads: int = 2, encoder_attention_heads: int = 2,
decoder_attention_heads: int = 2, decoder_attention_heads: int = 2,
encoder_layers: int = 2, encoder_layers: int = 2,
decoder_layers: int = 2, decoder_layers: int = 2,
is_encoder_decoder: bool = True, is_encoder_decoder: bool = True,
activation_function: str = "gelu", activation_function: str = "gelu",
dropout: float = 0.05, dropout: float = 0.05,
encoder_layerdrop: float = 0.1, encoder_layerdrop: float = 0.1,
decoder_layerdrop: float = 0.1, decoder_layerdrop: float = 0.1,
attention_dropout: float = 0.1, attention_dropout: float = 0.1,
activation_dropout: float = 0.1, activation_dropout: float = 0.1,
num_parallel_samples: int = 100, num_parallel_samples: int = 100,
init_std: float = 0.02, init_std: float = 0.02,
use_cache=True, use_cache=True,
# Informer arguments # Informer arguments
attn: str = "prob", attn: str = "prob",
factor: int = 5, factor: int = 5,
distil: bool = True, distil: bool = True,
**kwargs **kwargs
): ):
# time series specific configuration # time series specific configuration
self.prediction_length = prediction_length self.prediction_length = prediction_length
self.context_length = context_length or prediction_length self.context_length = context_length or prediction_length
self.distribution_output = distribution_output self.distribution_output = distribution_output
self.loss = loss # Eli: From vanilla ts transformer self.loss = loss # Eli: From vanilla ts transformer
self.input_size = input_size self.input_size = input_size
self.num_time_features = num_time_features self.num_time_features = num_time_features
self.lags_sequence = lags_sequence self.lags_sequence = lags_sequence

View File

@ -17,9 +17,12 @@
import random import random
from dataclasses import dataclass from dataclasses import dataclass
from math import sqrt
from typing import Callable, Dict, List, Optional, Tuple, Union from typing import Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch import torch
import torch.nn.functional as F
from torch import nn from torch import nn
from torch.distributions import ( from torch.distributions import (
AffineTransform, AffineTransform,
@ -37,11 +40,6 @@ from ...modeling_utils import PreTrainedModel
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_informer import InformerConfig from .configuration_informer import InformerConfig
from math import sqrt
from typing import List, Optional
import numpy as np
import torch.nn.functional as F
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@ -54,7 +52,6 @@ INFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
] ]
class AffineTransformed(TransformedDistribution): class AffineTransformed(TransformedDistribution):
def __init__(self, base_distribution: Distribution, loc=None, scale=None, event_dim=0): def __init__(self, base_distribution: Distribution, loc=None, scale=None, event_dim=0):
self.scale = 1.0 if scale is None else scale self.scale = 1.0 if scale is None else scale
@ -472,6 +469,7 @@ class Seq2SeqTimeSeriesModelOutput(ModelOutput):
scale: Optional[torch.FloatTensor] = None scale: Optional[torch.FloatTensor] = None
static_features: Optional[torch.FloatTensor] = None static_features: Optional[torch.FloatTensor] = None
# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer # Copied from transformers.models.time_series_transformer.modeling_time_series_transformer
@dataclass @dataclass
class Seq2SeqTimeSeriesPredictionOutput(ModelOutput): class Seq2SeqTimeSeriesPredictionOutput(ModelOutput):
@ -540,6 +538,7 @@ class Seq2SeqTimeSeriesPredictionOutput(ModelOutput):
scale: Optional[torch.FloatTensor] = None scale: Optional[torch.FloatTensor] = None
static_features: Optional[torch.FloatTensor] = None static_features: Optional[torch.FloatTensor] = None
# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer # Copied from transformers.models.time_series_transformer.modeling_time_series_transformer
@dataclass @dataclass
class SampleTimeSeriesPredictionOutput(ModelOutput): class SampleTimeSeriesPredictionOutput(ModelOutput):
@ -554,9 +553,7 @@ class TriangularCausalMask:
def __init__(self, B, L, device="cpu"): def __init__(self, B, L, device="cpu"):
mask_shape = [B, 1, L, L] mask_shape = [B, 1, L, L]
with torch.no_grad(): with torch.no_grad():
self._mask = torch.triu( self._mask = torch.triu(torch.ones(mask_shape, dtype=torch.bool), diagonal=1).to(device)
torch.ones(mask_shape, dtype=torch.bool), diagonal=1
).to(device)
@property @property
def mask(self): def mask(self):
@ -568,9 +565,7 @@ class ProbMask:
def __init__(self, B, H, L, index, scores, device="cpu"): def __init__(self, B, H, L, index, scores, device="cpu"):
_mask = torch.ones(L, scores.shape[-1], dtype=torch.bool).to(device).triu(1) _mask = torch.ones(L, scores.shape[-1], dtype=torch.bool).to(device).triu(1)
_mask_ex = _mask[None, None, :].expand(B, H, L, scores.shape[-1]) _mask_ex = _mask[None, None, :].expand(B, H, L, scores.shape[-1])
indicator = _mask_ex[ indicator = _mask_ex[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :].to(device)
torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :
].to(device)
self._mask = indicator.view(scores.shape).to(device) self._mask = indicator.view(scores.shape).to(device)
@property @property
@ -597,7 +592,7 @@ class FullAttention(nn.Module):
def forward(self, queries, keys, values, attn_mask): def forward(self, queries, keys, values, attn_mask):
B, L, H, E = queries.shape B, L, H, E = queries.shape
_, S, _, D = values.shape _, S, _, D = values.shape
scale = self.scale or 1. / sqrt(E) scale = self.scale or 1.0 / sqrt(E)
scores = torch.einsum("blhe,bshe->bhls", queries, keys) scores = torch.einsum("blhe,bshe->bhls", queries, keys)
if self.mask_flag: if self.mask_flag:
@ -673,14 +668,12 @@ class ProbAttention(nn.Module):
attn = torch.softmax(scores, dim=-1) # nn.Softmax(dim=-1)(scores) attn = torch.softmax(scores, dim=-1) # nn.Softmax(dim=-1)(scores)
context_in[ context_in[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :] = torch.matmul(
torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, : attn, V
] = torch.matmul(attn, V).type_as(context_in) ).type_as(context_in)
if self.output_attention: if self.output_attention:
attns = (torch.ones([B, H, L_V, L_V]) / L_V).type_as(attn).to(attn.device) attns = (torch.ones([B, H, L_V, L_V]) / L_V).type_as(attn).to(attn.device)
attns[ attns[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :] = attn
torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :
] = attn
return (context_in, attns) return (context_in, attns)
else: else:
return (context_in, None) return (context_in, None)
@ -708,18 +701,14 @@ class ProbAttention(nn.Module):
# get the context # get the context
context = self._get_initial_context(values, L_Q) context = self._get_initial_context(values, L_Q)
# update the context with selected top_k queries # update the context with selected top_k queries
context, attn = self._update_context( context, attn = self._update_context(context, values, scores_top, index, L_Q, attn_mask)
context, values, scores_top, index, L_Q, attn_mask
)
return context.transpose(2, 1).contiguous(), attn return context.transpose(2, 1).contiguous(), attn
# source: https://github.com/zhouhaoyi/Informer2020/blob/main/models/attn.py # source: https://github.com/zhouhaoyi/Informer2020/blob/main/models/attn.py
class AttentionLayer(nn.Module): class AttentionLayer(nn.Module):
def __init__( def __init__(self, attention, d_model, n_heads, d_keys=None, d_values=None, mix=False):
self, attention, d_model, n_heads, d_keys=None, d_values=None, mix=False
):
super(AttentionLayer, self).__init__() super(AttentionLayer, self).__init__()
d_keys = d_keys or (d_model // n_heads) d_keys = d_keys or (d_model // n_heads)
@ -761,13 +750,13 @@ class ConvLayer(nn.Module):
padding=1, padding=1,
padding_mode="circular", padding_mode="circular",
) )
self.norm = nn.BatchNorm1d(c_in) # Eli question: why batchnorm here? self.norm = nn.BatchNorm1d(c_in) # Eli question: why batchnorm here?
self.activation = nn.ELU() self.activation = nn.ELU()
self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1) self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
def forward(self, x): def forward(self, x):
x = self.downConv(x.permute(0, 2, 1)) x = self.downConv(x.permute(0, 2, 1))
x = self.norm(x) # Eli: why? maybe because the impl... x = self.norm(x) # Eli: why? maybe because the impl...
x = self.activation(x) x = self.activation(x)
x = self.maxPool(x) x = self.maxPool(x)
x = x.transpose(1, 2) x = x.transpose(1, 2)
@ -830,9 +819,7 @@ class DecoderLayer(nn.Module):
x = x + self.dropout(self.self_attention(x, x, x, attn_mask=x_mask)[0]) x = x + self.dropout(self.self_attention(x, x, x, attn_mask=x_mask)[0])
x = self.norm1(x) x = self.norm1(x)
x = x + self.dropout( x = x + self.dropout(self.cross_attention(x, cross, cross, attn_mask=cross_mask)[0])
self.cross_attention(x, cross, cross, attn_mask=cross_mask)[0]
)
y = x = self.norm2(x) y = x = self.norm2(x)
y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
@ -847,8 +834,9 @@ class InformerEncoder(nn.Module):
self.activation_fn = ACT2FN[config.activation_function] self.activation_fn = ACT2FN[config.activation_function]
Attn = ProbAttention if config.attn == "prob" else FullAttention Attn = ProbAttention if config.attn == "prob" else FullAttention
self.attn_layers = nn.ModuleList([ self.attn_layers = nn.ModuleList(
EncoderLayer( [
EncoderLayer(
AttentionLayer( AttentionLayer(
Attn( Attn(
mask_flag=False, mask_flag=False,
@ -864,8 +852,10 @@ class InformerEncoder(nn.Module):
d_ff=config.encoder_ffn_dim, d_ff=config.encoder_ffn_dim,
dropout=config.attention_dropout, dropout=config.attention_dropout,
activation=self.activation_fn, activation=self.activation_fn,
) for _ in range(config.encoder_layers) )
]) for _ in range(config.encoder_layers)
]
)
if config.distil is not None: if config.distil is not None:
self.conv_layers = nn.ModuleList([ConvLayer(config.d_model) for _ in range(config.encoder_layers - 1)]) self.conv_layers = nn.ModuleList([ConvLayer(config.d_model) for _ in range(config.encoder_layers - 1)])
@ -1000,22 +990,15 @@ class InformerModel(InformerPreTrainedModel):
self, sequence: torch.Tensor, subsequences_length: int, shift: int = 0 self, sequence: torch.Tensor, subsequences_length: int, shift: int = 0
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Returns lagged subsequences of a given sequence. Returns lagged subsequences of a given sequence. Parameters ---------- sequence : Tensor
Parameters the sequence from which lagged subsequences should be extracted. Shape: (N, T, C).
----------
sequence : Tensor
the sequence from which lagged subsequences should be extracted.
Shape: (N, T, C).
subsequences_length : int subsequences_length : int
length of the subsequences to be extracted. length of the subsequences to be extracted.
shift: int shift: int
shift the lags by this amount back. shift the lags by this amount back.
Returns Returns -------- lagged : Tensor
-------- a tensor of shape (N, S, C, I), where S = subsequences_length and I = len(indices), containing lagged
lagged : Tensor subsequences. Specifically, lagged[i, j, :, k] = sequence[i, -indices[k]-S+j, :].
a tensor of shape (N, S, C, I), where S = subsequences_length and
I = len(indices), containing lagged subsequences. Specifically,
lagged[i, j, :, k] = sequence[i, -indices[k]-S+j, :].
""" """
sequence_length = sequence.shape[1] sequence_length = sequence.shape[1]
indices = [lag - shift for lag in self.config.lags_sequence] indices = [lag - shift for lag in self.config.lags_sequence]
@ -1125,24 +1108,24 @@ class InformerModel(InformerPreTrainedModel):
return self.decoder return self.decoder
def forward( def forward(
self, self,
past_values: torch.Tensor, past_values: torch.Tensor,
past_time_features: torch.Tensor, past_time_features: torch.Tensor,
past_observed_mask: torch.Tensor, past_observed_mask: torch.Tensor,
static_categorical_features: torch.Tensor, static_categorical_features: torch.Tensor,
static_real_features: torch.Tensor, static_real_features: torch.Tensor,
future_values: Optional[torch.Tensor] = None, future_values: Optional[torch.Tensor] = None,
future_time_features: Optional[torch.Tensor] = None, future_time_features: Optional[torch.Tensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None, decoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None,
decoder_head_mask: Optional[torch.Tensor] = None, decoder_head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[List[torch.FloatTensor]] = None, encoder_outputs: Optional[List[torch.FloatTensor]] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None, past_key_values: Optional[List[torch.FloatTensor]] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
use_cache: Optional[bool] = None, use_cache: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
) -> Union[Seq2SeqTimeSeriesModelOutput, Tuple]: ) -> Union[Seq2SeqTimeSeriesModelOutput, Tuple]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = ( output_hidden_states = (
@ -1178,7 +1161,7 @@ class InformerModel(InformerPreTrainedModel):
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
) )
dec_input = transformer_inputs[:, self.config.context_length:, ...] dec_input = transformer_inputs[:, self.config.context_length :, ...]
decoder_outputs = self.decoder( decoder_outputs = self.decoder(
inputs_embeds=dec_input, inputs_embeds=dec_input,
attention_mask=decoder_attention_mask, attention_mask=decoder_attention_mask,
@ -1462,6 +1445,3 @@ class InformerForPrediction(InformerPreTrainedModel):
(-1, num_parallel_samples, self.config.prediction_length) + self.target_shape, (-1, num_parallel_samples, self.config.prediction_length) + self.target_shape,
) )
) )

View File

@ -31,15 +31,8 @@ TOLERANCE = 1e-4
if is_torch_available(): if is_torch_available():
import torch import torch
from transformers import ( from transformers import InformerConfig, InformerForPrediction, InformerModel
InformerConfig, from transformers.models.informer.modeling_informer import InformerDecoder, InformerEncoder
InformerForPrediction,
InformerModel,
)
from transformers.models.informer.modeling_informer import (
InformerDecoder,
InformerEncoder,
)
@require_torch @require_torch
@ -171,9 +164,7 @@ class InformerModelTester:
@require_torch @require_torch
class InformerModelTest(ModelTesterMixin, unittest.TestCase): class InformerModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (InformerModel, InformerForPrediction) if is_torch_available() else ()
(InformerModel, InformerForPrediction) if is_torch_available() else ()
)
all_generative_model_classes = (InformerForPrediction,) if is_torch_available() else () all_generative_model_classes = (InformerForPrediction,) if is_torch_available() else ()
is_encoder_decoder = True is_encoder_decoder = True
test_pruning = False test_pruning = False
@ -374,9 +365,7 @@ def prepare_batch(filename="train-batch.pt"):
@slow @slow
class InformerModelIntegrationTests(unittest.TestCase): class InformerModelIntegrationTests(unittest.TestCase):
def test_inference_no_head(self): def test_inference_no_head(self):
model = InformerModel.from_pretrained("huggingface/time-series-transformer-tourism-monthly").to( model = InformerModel.from_pretrained("huggingface/time-series-transformer-tourism-monthly").to(torch_device)
torch_device
)
batch = prepare_batch() batch = prepare_batch()
with torch.no_grad(): with torch.no_grad():
@ -399,9 +388,9 @@ class InformerModelIntegrationTests(unittest.TestCase):
self.assertTrue(torch.allclose(output[0, :3, :3], expected_slice, atol=TOLERANCE)) self.assertTrue(torch.allclose(output[0, :3, :3], expected_slice, atol=TOLERANCE))
def test_inference_head(self): def test_inference_head(self):
model = InformerForPrediction.from_pretrained( model = InformerForPrediction.from_pretrained("huggingface/time-series-transformer-tourism-monthly").to(
"huggingface/time-series-transformer-tourism-monthly" torch_device
).to(torch_device) )
batch = prepare_batch("val-batch.pt") batch = prepare_batch("val-batch.pt")
with torch.no_grad(): with torch.no_grad():
output = model( output = model(
@ -421,9 +410,9 @@ class InformerModelIntegrationTests(unittest.TestCase):
self.assertTrue(torch.allclose(output[0, :3, :3], expected_slice, atol=TOLERANCE)) self.assertTrue(torch.allclose(output[0, :3, :3], expected_slice, atol=TOLERANCE))
def test_seq_to_seq_generation(self): def test_seq_to_seq_generation(self):
model = InformerForPrediction.from_pretrained( model = InformerForPrediction.from_pretrained("huggingface/time-series-transformer-tourism-monthly").to(
"huggingface/time-series-transformer-tourism-monthly" torch_device
).to(torch_device) )
batch = prepare_batch("val-batch.pt") batch = prepare_batch("val-batch.pt")
with torch.no_grad(): with torch.no_grad():
outputs = model.generate( outputs = model.generate(