From 36ade4a32b2cc47f19fce2987065082a696f89a5 Mon Sep 17 00:00:00 2001
From: Max Strobel
Date: Tue, 4 Jun 2024 11:11:36 +0200
Subject: [PATCH] fix(PatchTST): Wrong dropout used for PretrainHead (#31117)

* fix(PatchTST): Wrong dropout used for PretrainHead

* feat(PatchTST): remove unused config.dropout

---------

Co-authored-by: Strobel Maximilian (IFAG PSS SIS SCE ACM)
---
 src/transformers/models/patchtst/configuration_patchtst.py | 4 ----
 src/transformers/models/patchtst/modeling_patchtst.py      | 2 +-
 2 files changed, 1 insertion(+), 5 deletions(-)

diff --git a/src/transformers/models/patchtst/configuration_patchtst.py b/src/transformers/models/patchtst/configuration_patchtst.py
index acae3d0dc60..29d14491752 100644
--- a/src/transformers/models/patchtst/configuration_patchtst.py
+++ b/src/transformers/models/patchtst/configuration_patchtst.py
@@ -67,8 +67,6 @@ class PatchTSTConfig(PretrainedConfig):
             A value added to the denominator for numerical stability of normalization.
         attention_dropout (`float`, *optional*, defaults to 0.0):
             The dropout probability for the attention probabilities.
-        dropout (`float`, *optional*, defaults to 0.0):
-            The dropout probability for all fully connected layers in the Transformer.
         positional_dropout (`float`, *optional*, defaults to 0.0):
             The dropout probability in the positional embedding layer.
         path_dropout (`float`, *optional*, defaults to 0.0):
@@ -167,7 +165,6 @@ class PatchTSTConfig(PretrainedConfig):
         norm_type: str = "batchnorm",
         norm_eps: float = 1e-05,
         attention_dropout: float = 0.0,
-        dropout: float = 0.0,
         positional_dropout: float = 0.0,
         path_dropout: float = 0.0,
         ff_dropout: float = 0.0,
@@ -209,7 +206,6 @@ class PatchTSTConfig(PretrainedConfig):
         self.num_attention_heads = num_attention_heads
         self.ffn_dim = ffn_dim
         self.num_hidden_layers = num_hidden_layers
-        self.dropout = dropout
         self.attention_dropout = attention_dropout
         self.share_embedding = share_embedding
         self.channel_attention = channel_attention
diff --git a/src/transformers/models/patchtst/modeling_patchtst.py b/src/transformers/models/patchtst/modeling_patchtst.py
index e30e4572834..3c761bcae77 100755
--- a/src/transformers/models/patchtst/modeling_patchtst.py
+++ b/src/transformers/models/patchtst/modeling_patchtst.py
@@ -1262,7 +1262,7 @@ class PatchTSTMaskPretrainHead(nn.Module):
 
     def __init__(self, config: PatchTSTConfig):
         super().__init__()
-        self.dropout = nn.Dropout(config.dropout)
+        self.dropout = nn.Dropout(config.head_dropout) if config.head_dropout > 0 else nn.Identity()
         self.linear = nn.Linear(config.d_model, config.patch_length)
         self.use_cls_token = config.use_cls_token
 
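
Usage sketch (not part of the patch): a minimal check of the fixed behavior. It assumes the public transformers classes PatchTSTConfig and PatchTSTForPretraining, and that PatchTSTForPretraining exposes the pretraining head as model.head; that attribute path is an assumption about surrounding modeling code not shown in this diff.

    from torch import nn
    from transformers import PatchTSTConfig, PatchTSTForPretraining

    # After this fix, head_dropout drives the dropout in front of the
    # pretrain head's linear projection (config.dropout no longer exists).
    config = PatchTSTConfig(head_dropout=0.2)
    model = PatchTSTForPretraining(config)
    assert isinstance(model.head.dropout, nn.Dropout)  # assumes the head is exposed as model.head
    assert model.head.dropout.p == 0.2

    # With the default head_dropout == 0.0, the head falls back to nn.Identity.
    model_no_dropout = PatchTSTForPretraining(PatchTSTConfig(head_dropout=0.0))
    assert isinstance(model_no_dropout.head.dropout, nn.Identity)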