From 4ed075280c30051126ba6e8d1634867abeb0fbc4 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 13 Jun 2023 18:46:05 +0200 Subject: [PATCH] [Time Series] use mean scaler when scaling is a boolean True (#24237) * use mean scaler when scaling is boolean True * remove debug --- src/transformers/models/autoformer/modeling_autoformer.py | 2 +- src/transformers/models/informer/modeling_informer.py | 2 +- .../time_series_transformer/modeling_time_series_transformer.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/autoformer/modeling_autoformer.py b/src/transformers/models/autoformer/modeling_autoformer.py index a77920fb9d6..3e482cadf67 100644 --- a/src/transformers/models/autoformer/modeling_autoformer.py +++ b/src/transformers/models/autoformer/modeling_autoformer.py @@ -1495,7 +1495,7 @@ class AutoformerModel(AutoformerPreTrainedModel): def __init__(self, config: AutoformerConfig): super().__init__(config) - if config.scaling == "mean" or config.scaling: + if config.scaling == "mean" or config.scaling is True: self.scaler = AutoformerMeanScaler(dim=1, keepdim=True) elif config.scaling == "std": self.scaler = AutoformerStdScaler(dim=1, keepdim=True) diff --git a/src/transformers/models/informer/modeling_informer.py b/src/transformers/models/informer/modeling_informer.py index 4c8edcbc156..2bf3f208a90 100644 --- a/src/transformers/models/informer/modeling_informer.py +++ b/src/transformers/models/informer/modeling_informer.py @@ -1504,7 +1504,7 @@ class InformerModel(InformerPreTrainedModel): def __init__(self, config: InformerConfig): super().__init__(config) - if config.scaling == "mean" or config.scaling: + if config.scaling == "mean" or config.scaling is True: self.scaler = InformerMeanScaler(dim=1, keepdim=True) elif config.scaling == "std": self.scaler = InformerStdScaler(dim=1, keepdim=True) diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index d5ffa069d95..8986ef6729c 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -1229,7 +1229,7 @@ class TimeSeriesTransformerModel(TimeSeriesTransformerPreTrainedModel): def __init__(self, config: TimeSeriesTransformerConfig): super().__init__(config) - if config.scaling == "mean" or config.scaling: + if config.scaling == "mean" or config.scaling is True: self.scaler = TimeSeriesMeanScaler(dim=1, keepdim=True) elif config.scaling == "std": self.scaler = TimeSeriesStdScaler(dim=1, keepdim=True)