Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-03 03:31:05 +06:00)

commit fdffeb819c
parent 83d39df0b1

    fix style
@@ -290,6 +290,10 @@ _import_structure = {
     "models.hubert": ["HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "HubertConfig"],
     "models.ibert": ["IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "IBertConfig"],
     "models.imagegpt": ["IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ImageGPTConfig"],
+    "models.informer": [
+        "INFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
+        "InformerConfig",
+    ],
     "models.jukebox": [
         "JUKEBOX_PRETRAINED_CONFIG_ARCHIVE_MAP",
         "JukeboxConfig",
@@ -414,10 +418,6 @@ _import_structure = {
         "TIME_SERIES_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
         "TimeSeriesTransformerConfig",
     ],
-    "models.informer": [
-        "INFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
-        "InformerConfig",
-    ],
     "models.timesformer": ["TIMESFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "TimesformerConfig"],
     "models.trajectory_transformer": [
         "TRAJECTORY_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
@@ -1621,6 +1621,14 @@ else:
             "load_tf_weights_in_imagegpt",
         ]
     )
+    _import_structure["models.informer"].extend(
+        [
+            "INFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
+            "InformerForPrediction",
+            "InformerModel",
+            "InformerPreTrainedModel",
+        ]
+    )
     _import_structure["models.jukebox"].extend(
         [
             "JUKEBOX_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -2275,14 +2283,6 @@ else:
             "TimeSeriesTransformerPreTrainedModel",
         ]
     )
-    _import_structure["models.informer"].extend(
-        [
-            "INFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
-            "InformerForPrediction",
-            "InformerModel",
-            "InformerPreTrainedModel",
-        ]
-    )
     _import_structure["models.timesformer"].extend(
         [
             "TIMESFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -3741,6 +3741,7 @@ if TYPE_CHECKING:
     from .models.hubert import HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, HubertConfig
     from .models.ibert import IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, IBertConfig
     from .models.imagegpt import IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP, ImageGPTConfig
+    from .models.informer import INFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, InformerConfig
     from .models.jukebox import (
         JUKEBOX_PRETRAINED_CONFIG_ARCHIVE_MAP,
         JukeboxConfig,
@@ -3855,10 +3856,6 @@ if TYPE_CHECKING:
         TIME_SERIES_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
         TimeSeriesTransformerConfig,
     )
-    from .models.informer import (
-        INFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
-        InformerConfig,
-    )
     from .models.timesformer import TIMESFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, TimesformerConfig
     from .models.trajectory_transformer import (
         TRAJECTORY_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
@@ -4865,6 +4862,12 @@ if TYPE_CHECKING:
             ImageGPTPreTrainedModel,
             load_tf_weights_in_imagegpt,
         )
+        from .models.informer import (
+            INFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
+            InformerForPrediction,
+            InformerModel,
+            InformerPreTrainedModel,
+        )
         from .models.jukebox import (
             JUKEBOX_PRETRAINED_MODEL_ARCHIVE_LIST,
             JukeboxModel,
@@ -5393,12 +5396,6 @@ if TYPE_CHECKING:
             TimeSeriesTransformerModel,
             TimeSeriesTransformerPreTrainedModel,
         )
-        from .models.informer import (
-            INFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
-            InformerForPrediction,
-            InformerModel,
-            InformerPreTrainedModel,
-        )
         from .models.timesformer import (
             TIMESFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
             TimesformerForVideoClassification,
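The hunks above wire Informer into the lazy import machinery of the top-level transformers package: the symbol names are listed in `_import_structure` (and its torch-only section) and mirrored as real imports under `TYPE_CHECKING`. A minimal sketch of the effect, assuming the PR branch is installed (the model classes additionally require torch):

# With the registrations above in place, the new symbols resolve from the top level;
# the modeling module itself is only imported lazily on first attribute access.
from transformers import InformerConfig, InformerForPrediction, InformerModel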
@@ -90,6 +90,7 @@ from . import (
     hubert,
     ibert,
     imagegpt,
+    informer,
     jukebox,
     layoutlm,
     layoutlmv2,
@@ -165,7 +166,6 @@ from . import (
     tapas,
     tapex,
     time_series_transformer,
-    informer,
     timesformer,
     trajectory_transformer,
     transfo_xl,
@@ -93,6 +93,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
         ("hubert", "HubertConfig"),
         ("ibert", "IBertConfig"),
         ("imagegpt", "ImageGPTConfig"),
+        ("informer", "InformerConfig"),
         ("jukebox", "JukeboxConfig"),
         ("layoutlm", "LayoutLMConfig"),
         ("layoutlmv2", "LayoutLMv2Config"),
@@ -161,7 +162,6 @@ CONFIG_MAPPING_NAMES = OrderedDict(
         ("table-transformer", "TableTransformerConfig"),
         ("tapas", "TapasConfig"),
         ("time_series_transformer", "TimeSeriesTransformerConfig"),
-        ("informer", "InformerConfig"),
         ("timesformer", "TimesformerConfig"),
         ("trajectory_transformer", "TrajectoryTransformerConfig"),
         ("transfo-xl", "TransfoXLConfig"),
@@ -258,6 +258,7 @@ CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
         ("hubert", "HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
         ("ibert", "IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
         ("imagegpt", "IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+        ("informer", "INFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
         ("jukebox", "JUKEBOX_PRETRAINED_CONFIG_ARCHIVE_MAP"),
         ("layoutlm", "LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP"),
         ("layoutlmv2", "LAYOUTLMV2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
@@ -319,7 +320,6 @@ CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
         ("table-transformer", "TABLE_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
         ("tapas", "TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP"),
         ("time_series_transformer", "TIME_SERIES_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
-        ("informer", "INFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
         ("timesformer", "TIMESFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
         ("transfo-xl", "TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
         ("unispeech", "UNISPEECH_PRETRAINED_CONFIG_ARCHIVE_MAP"),
@@ -424,6 +424,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
         ("hubert", "Hubert"),
         ("ibert", "I-BERT"),
         ("imagegpt", "ImageGPT"),
+        ("informer", "Informer"),
         ("jukebox", "Jukebox"),
         ("layoutlm", "LayoutLM"),
         ("layoutlmv2", "LayoutLMv2"),
@@ -500,7 +501,6 @@ MODEL_NAMES_MAPPING = OrderedDict(
         ("tapas", "TAPAS"),
         ("tapex", "TAPEX"),
         ("time_series_transformer", "Time Series Transformer"),
-        ("informer", "Informer"),
         ("timesformer", "TimeSformer"),
         ("trajectory_transformer", "Trajectory Transformer"),
         ("transfo-xl", "Transformer-XL"),
@@ -92,6 +92,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
         ("hubert", "HubertModel"),
         ("ibert", "IBertModel"),
         ("imagegpt", "ImageGPTModel"),
+        ("informer", "InformerModel"),
         ("jukebox", "JukeboxModel"),
         ("layoutlm", "LayoutLMModel"),
         ("layoutlmv2", "LayoutLMv2Model"),
@@ -157,7 +158,6 @@ MODEL_MAPPING_NAMES = OrderedDict(
         ("table-transformer", "TableTransformerModel"),
         ("tapas", "TapasModel"),
         ("time_series_transformer", "TimeSeriesTransformerModel"),
-        ("informer", "InformerModel"),
         ("timesformer", "TimesformerModel"),
         ("trajectory_transformer", "TrajectoryTransformerModel"),
         ("transfo-xl", "TransfoXLModel"),
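The auto-mapping hunks above (configuration and modeling mappings) are what let the Auto classes resolve the new "informer" model type by its string name. A hedged sketch of that lookup, again assuming the PR branch is installed:

from transformers import AutoConfig

# AutoConfig.for_model() resolves the registered model-type string through
# CONFIG_MAPPING_NAMES, which the hunk above now points at InformerConfig.
config = AutoConfig.for_model("informer")
print(type(config).__name__)  # expected: InformerConfig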
@@ -43,10 +43,7 @@ else:


 if TYPE_CHECKING:
-    from .configuration_informer import (
-        INFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
-        InformerConfig,
-    )
+    from .configuration_informer import INFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, InformerConfig

     try:
         if not is_torch_available():
@@ -27,13 +27,11 @@ INFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
 }


-
 class InformerConfig(PretrainedConfig):
     r"""
-    This is the configuration class to store the configuration of a [`InformerModel`]. It is used to
-    instantiate an Informer model according to the specified arguments, defining the model architecture.
-    Instantiating a configuration with the defaults will yield a similar configuration to that of the Time Series
-    Transformer
+    This is the configuration class to store the configuration of a [`InformerModel`]. It is used to instantiate an
+    Informer model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of the Time Series Transformer
     [huggingface/time-series-transformer-tourism-monthly](https://huggingface.co/huggingface/time-series-transformer-tourism-monthly)
     architecture.

@@ -136,47 +134,47 @@ class InformerConfig(PretrainedConfig):
     }

     def __init__(
         self,
         input_size: int = 1,
         prediction_length: Optional[int] = None,
         context_length: Optional[int] = None,
         distribution_output: str = "student_t",
         loss: str = "nll",
         lags_sequence: List[int] = None,
         scaling: bool = True,
         num_dynamic_real_features: int = 0,
         num_static_real_features: int = 0,
         num_static_categorical_features: int = 0,
         num_time_features: int = 0,
         cardinality: Optional[List[int]] = None,
         embedding_dimension: Optional[List[int]] = None,
         encoder_ffn_dim: int = 32,
         decoder_ffn_dim: int = 32,
         encoder_attention_heads: int = 2,
         decoder_attention_heads: int = 2,
         encoder_layers: int = 2,
         decoder_layers: int = 2,
         is_encoder_decoder: bool = True,
         activation_function: str = "gelu",
         dropout: float = 0.05,
         encoder_layerdrop: float = 0.1,
         decoder_layerdrop: float = 0.1,
         attention_dropout: float = 0.1,
         activation_dropout: float = 0.1,
         num_parallel_samples: int = 100,
         init_std: float = 0.02,
         use_cache=True,
         # Informer arguments
         attn: str = "prob",
         factor: int = 5,
         distil: bool = True,
         **kwargs
     ):
         # time series specific configuration
         self.prediction_length = prediction_length
         self.context_length = context_length or prediction_length
         self.distribution_output = distribution_output
         self.loss = loss  # Eli: From vanilla ts transformer
         self.input_size = input_size
         self.num_time_features = num_time_features
         self.lags_sequence = lags_sequence
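The constructor above mixes the vanilla time series transformer arguments with the three Informer-specific ones (`attn`, `factor`, `distil`). A small, hypothetical configuration built from that signature (values are illustrative only and require the PR branch):

from transformers import InformerConfig

config = InformerConfig(
    prediction_length=24,
    context_length=48,
    lags_sequence=[1, 2, 3, 4, 5, 6, 7],
    num_time_features=2,
    encoder_layers=2,
    decoder_layers=2,
    attn="prob",  # "prob" selects ProbAttention; anything else falls back to FullAttention
    factor=5,     # sampling factor used by ProbAttention
    distil=True,  # convolutional distilling between encoder layers
)
print(config.prediction_length, config.context_length)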
@@ -17,9 +17,12 @@

 import random
 from dataclasses import dataclass
+from math import sqrt
 from typing import Callable, Dict, List, Optional, Tuple, Union

+import numpy as np
 import torch
+import torch.nn.functional as F
 from torch import nn
 from torch.distributions import (
     AffineTransform,
@@ -37,11 +40,6 @@ from ...modeling_utils import PreTrainedModel
 from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from .configuration_informer import InformerConfig

-from math import sqrt
-from typing import List, Optional
-
-import numpy as np
-import torch.nn.functional as F

 logger = logging.get_logger(__name__)

@@ -54,7 +52,6 @@ INFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
 ]


-
 class AffineTransformed(TransformedDistribution):
     def __init__(self, base_distribution: Distribution, loc=None, scale=None, event_dim=0):
         self.scale = 1.0 if scale is None else scale
@@ -472,6 +469,7 @@ class Seq2SeqTimeSeriesModelOutput(ModelOutput):
     scale: Optional[torch.FloatTensor] = None
     static_features: Optional[torch.FloatTensor] = None

+
 # Copied from transformers.models.time_series_transformer.modeling_time_series_transformer
 @dataclass
 class Seq2SeqTimeSeriesPredictionOutput(ModelOutput):
@@ -540,6 +538,7 @@ class Seq2SeqTimeSeriesPredictionOutput(ModelOutput):
     scale: Optional[torch.FloatTensor] = None
     static_features: Optional[torch.FloatTensor] = None

+
 # Copied from transformers.models.time_series_transformer.modeling_time_series_transformer
 @dataclass
 class SampleTimeSeriesPredictionOutput(ModelOutput):
@@ -554,9 +553,7 @@ class TriangularCausalMask:
     def __init__(self, B, L, device="cpu"):
         mask_shape = [B, 1, L, L]
         with torch.no_grad():
-            self._mask = torch.triu(
-                torch.ones(mask_shape, dtype=torch.bool), diagonal=1
-            ).to(device)
+            self._mask = torch.triu(torch.ones(mask_shape, dtype=torch.bool), diagonal=1).to(device)

     @property
     def mask(self):
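`TriangularCausalMask` above keeps only the strictly upper-triangular part of an all-ones tensor, i.e. the future positions a query must not attend to. A standalone sketch of that mask construction:

import torch

B, L = 2, 4
mask = torch.triu(torch.ones([B, 1, L, L], dtype=torch.bool), diagonal=1)
print(mask[0, 0].int())
# tensor([[0, 1, 1, 1],
#         [0, 0, 1, 1],
#         [0, 0, 0, 1],
#         [0, 0, 0, 0]])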
@@ -568,9 +565,7 @@ class ProbMask:
     def __init__(self, B, H, L, index, scores, device="cpu"):
         _mask = torch.ones(L, scores.shape[-1], dtype=torch.bool).to(device).triu(1)
         _mask_ex = _mask[None, None, :].expand(B, H, L, scores.shape[-1])
-        indicator = _mask_ex[
-            torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :
-        ].to(device)
+        indicator = _mask_ex[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :].to(device)
         self._mask = indicator.view(scores.shape).to(device)

     @property
@@ -597,7 +592,7 @@ class FullAttention(nn.Module):
     def forward(self, queries, keys, values, attn_mask):
         B, L, H, E = queries.shape
         _, S, _, D = values.shape
-        scale = self.scale or 1. / sqrt(E)
+        scale = self.scale or 1.0 / sqrt(E)

         scores = torch.einsum("blhe,bshe->bhls", queries, keys)
         if self.mask_flag:
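The reformatted `scale` line above pairs with the einsum on the following context line: queries of shape (B, L, H, E) and keys of shape (B, S, H, E) are contracted over the per-head feature dimension E, giving per-head scores of shape (B, H, L, S). A minimal sketch of that computation:

from math import sqrt

import torch

B, L, S, H, E = 2, 5, 5, 3, 8
queries = torch.randn(B, L, H, E)
keys = torch.randn(B, S, H, E)

scale = 1.0 / sqrt(E)
scores = torch.einsum("blhe,bshe->bhls", queries, keys)
attn = torch.softmax(scale * scores, dim=-1)
print(attn.shape)  # torch.Size([2, 3, 5, 5])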
@@ -673,14 +668,12 @@ class ProbAttention(nn.Module):

         attn = torch.softmax(scores, dim=-1)  # nn.Softmax(dim=-1)(scores)

-        context_in[
-            torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :
-        ] = torch.matmul(attn, V).type_as(context_in)
+        context_in[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :] = torch.matmul(
+            attn, V
+        ).type_as(context_in)
         if self.output_attention:
             attns = (torch.ones([B, H, L_V, L_V]) / L_V).type_as(attn).to(attn.device)
-            attns[
-                torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :
-            ] = attn
+            attns[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :] = attn
             return (context_in, attns)
         else:
             return (context_in, None)
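The single-line indexing expressions restored above use the same batched advanced-indexing trick as `ProbMask`: broadcasting `torch.arange(B)` and `torch.arange(H)` against a per-(batch, head) index tensor so that only the selected query positions are read or overwritten. A small sketch of the pattern with hypothetical sizes:

import torch

B, H, L_Q, top_k, D = 2, 3, 6, 2, 4
context = torch.zeros(B, H, L_Q, D)
index = torch.randint(0, L_Q, (B, H, top_k))  # per-(batch, head) selected query positions
update = torch.ones(B, H, top_k, D)

# Rows listed in `index` are overwritten for every (batch, head) pair; other rows stay untouched.
context[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :] = update
print(int(context.sum()))  # at most B * H * top_k * D; less if sampled indices repeat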
@@ -708,18 +701,14 @@ class ProbAttention(nn.Module):
         # get the context
         context = self._get_initial_context(values, L_Q)
         # update the context with selected top_k queries
-        context, attn = self._update_context(
-            context, values, scores_top, index, L_Q, attn_mask
-        )
+        context, attn = self._update_context(context, values, scores_top, index, L_Q, attn_mask)

         return context.transpose(2, 1).contiguous(), attn


 # source: https://github.com/zhouhaoyi/Informer2020/blob/main/models/attn.py
 class AttentionLayer(nn.Module):
-    def __init__(
-        self, attention, d_model, n_heads, d_keys=None, d_values=None, mix=False
-    ):
+    def __init__(self, attention, d_model, n_heads, d_keys=None, d_values=None, mix=False):
         super(AttentionLayer, self).__init__()

         d_keys = d_keys or (d_model // n_heads)
@@ -761,13 +750,13 @@ class ConvLayer(nn.Module):
             padding=1,
             padding_mode="circular",
         )
         self.norm = nn.BatchNorm1d(c_in)  # Eli question: why batchnorm here?
         self.activation = nn.ELU()
         self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

     def forward(self, x):
         x = self.downConv(x.permute(0, 2, 1))
         x = self.norm(x)  # Eli: why? maybe because the impl...
         x = self.activation(x)
         x = self.maxPool(x)
         x = x.transpose(1, 2)
@@ -830,9 +819,7 @@ class DecoderLayer(nn.Module):
         x = x + self.dropout(self.self_attention(x, x, x, attn_mask=x_mask)[0])
         x = self.norm1(x)

-        x = x + self.dropout(
-            self.cross_attention(x, cross, cross, attn_mask=cross_mask)[0]
-        )
+        x = x + self.dropout(self.cross_attention(x, cross, cross, attn_mask=cross_mask)[0])

         y = x = self.norm2(x)
         y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
@@ -847,8 +834,9 @@ class InformerEncoder(nn.Module):

         self.activation_fn = ACT2FN[config.activation_function]
         Attn = ProbAttention if config.attn == "prob" else FullAttention
-        self.attn_layers = nn.ModuleList([
-            EncoderLayer(
+        self.attn_layers = nn.ModuleList(
+            [
+                EncoderLayer(
                     AttentionLayer(
                         Attn(
                             mask_flag=False,
@@ -864,8 +852,10 @@ class InformerEncoder(nn.Module):
                     d_ff=config.encoder_ffn_dim,
                     dropout=config.attention_dropout,
                     activation=self.activation_fn,
-            ) for _ in range(config.encoder_layers)
-        ])
+                )
+                for _ in range(config.encoder_layers)
+            ]
+        )

         if config.distil is not None:
             self.conv_layers = nn.ModuleList([ConvLayer(config.d_model) for _ in range(config.encoder_layers - 1)])
@@ -1000,22 +990,15 @@ class InformerModel(InformerPreTrainedModel):
         self, sequence: torch.Tensor, subsequences_length: int, shift: int = 0
     ) -> torch.Tensor:
         """
-        Returns lagged subsequences of a given sequence.
-        Parameters
-        ----------
-        sequence : Tensor
-            the sequence from which lagged subsequences should be extracted.
-            Shape: (N, T, C).
+        Returns lagged subsequences of a given sequence. Parameters ---------- sequence : Tensor
+            the sequence from which lagged subsequences should be extracted. Shape: (N, T, C).
         subsequences_length : int
             length of the subsequences to be extracted.
         shift: int
             shift the lags by this amount back.
-        Returns
-        --------
-        lagged : Tensor
-            a tensor of shape (N, S, C, I), where S = subsequences_length and
-            I = len(indices), containing lagged subsequences. Specifically,
-            lagged[i, j, :, k] = sequence[i, -indices[k]-S+j, :].
+        Returns -------- lagged : Tensor
+            a tensor of shape (N, S, C, I), where S = subsequences_length and I = len(indices), containing lagged
+            subsequences. Specifically, lagged[i, j, :, k] = sequence[i, -indices[k]-S+j, :].
         """
         sequence_length = sequence.shape[1]
         indices = [lag - shift for lag in self.config.lags_sequence]
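The rewrapped docstring above still pins down the semantics: lagged[i, j, :, k] = sequence[i, -indices[k] - S + j, :]. A direct sketch of that extraction for lags of at least 1 (sizes are hypothetical):

import torch

N, T, C = 2, 50, 1   # batch, time steps, input channels
S = 10               # subsequences_length
indices = [1, 2, 7]  # lags, all >= 1 here so the negative slices below are well formed

sequence = torch.randn(N, T, C)
lagged = torch.stack([sequence[:, -lag - S : -lag, :] for lag in indices], dim=-1)
print(lagged.shape)  # torch.Size([2, 10, 1, 3]) == (N, S, C, len(indices))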
@@ -1125,24 +1108,24 @@ class InformerModel(InformerPreTrainedModel):
         return self.decoder

     def forward(
         self,
         past_values: torch.Tensor,
         past_time_features: torch.Tensor,
         past_observed_mask: torch.Tensor,
         static_categorical_features: torch.Tensor,
         static_real_features: torch.Tensor,
         future_values: Optional[torch.Tensor] = None,
         future_time_features: Optional[torch.Tensor] = None,
         decoder_attention_mask: Optional[torch.LongTensor] = None,
         head_mask: Optional[torch.Tensor] = None,
         decoder_head_mask: Optional[torch.Tensor] = None,
         cross_attn_head_mask: Optional[torch.Tensor] = None,
         encoder_outputs: Optional[List[torch.FloatTensor]] = None,
         past_key_values: Optional[List[torch.FloatTensor]] = None,
         output_hidden_states: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         use_cache: Optional[bool] = None,
         return_dict: Optional[bool] = None,
     ) -> Union[Seq2SeqTimeSeriesModelOutput, Tuple]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -1178,7 +1161,7 @@ class InformerModel(InformerPreTrainedModel):
                 attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
             )

-        dec_input = transformer_inputs[:, self.config.context_length:, ...]
+        dec_input = transformer_inputs[:, self.config.context_length :, ...]
         decoder_outputs = self.decoder(
             inputs_embeds=dec_input,
             attention_mask=decoder_attention_mask,
@@ -1462,6 +1445,3 @@ class InformerForPrediction(InformerPreTrainedModel):
                 (-1, num_parallel_samples, self.config.prediction_length) + self.target_shape,
             )
         )
-
-
-
@@ -31,15 +31,8 @@ TOLERANCE = 1e-4
 if is_torch_available():
     import torch

-    from transformers import (
-        InformerConfig,
-        InformerForPrediction,
-        InformerModel,
-    )
-    from transformers.models.informer.modeling_informer import (
-        InformerDecoder,
-        InformerEncoder,
-    )
+    from transformers import InformerConfig, InformerForPrediction, InformerModel
+    from transformers.models.informer.modeling_informer import InformerDecoder, InformerEncoder


 @require_torch
@@ -171,9 +164,7 @@ class InformerModelTester:

 @require_torch
 class InformerModelTest(ModelTesterMixin, unittest.TestCase):
-    all_model_classes = (
-        (InformerModel, InformerForPrediction) if is_torch_available() else ()
-    )
+    all_model_classes = (InformerModel, InformerForPrediction) if is_torch_available() else ()
     all_generative_model_classes = (InformerForPrediction,) if is_torch_available() else ()
     is_encoder_decoder = True
     test_pruning = False
@@ -374,9 +365,7 @@ def prepare_batch(filename="train-batch.pt"):
 @slow
 class InformerModelIntegrationTests(unittest.TestCase):
     def test_inference_no_head(self):
-        model = InformerModel.from_pretrained("huggingface/time-series-transformer-tourism-monthly").to(
-            torch_device
-        )
+        model = InformerModel.from_pretrained("huggingface/time-series-transformer-tourism-monthly").to(torch_device)
         batch = prepare_batch()

         with torch.no_grad():
@@ -399,9 +388,9 @@ class InformerModelIntegrationTests(unittest.TestCase):
         self.assertTrue(torch.allclose(output[0, :3, :3], expected_slice, atol=TOLERANCE))

     def test_inference_head(self):
-        model = InformerForPrediction.from_pretrained(
-            "huggingface/time-series-transformer-tourism-monthly"
-        ).to(torch_device)
+        model = InformerForPrediction.from_pretrained("huggingface/time-series-transformer-tourism-monthly").to(
+            torch_device
+        )
         batch = prepare_batch("val-batch.pt")
         with torch.no_grad():
             output = model(
@@ -421,9 +410,9 @@ class InformerModelIntegrationTests(unittest.TestCase):
         self.assertTrue(torch.allclose(output[0, :3, :3], expected_slice, atol=TOLERANCE))

     def test_seq_to_seq_generation(self):
-        model = InformerForPrediction.from_pretrained(
-            "huggingface/time-series-transformer-tourism-monthly"
-        ).to(torch_device)
+        model = InformerForPrediction.from_pretrained("huggingface/time-series-transformer-tourism-monthly").to(
+            torch_device
+        )
         batch = prepare_batch("val-batch.pt")
         with torch.no_grad():
             outputs = model.generate(