fix(PatchTST): Wrong dropout used for PretainHead (#31117)

* fix(PatchTST): Wrong dropout used for PretainHead

* feat(PatchTST): remove unused config.dropout

---------

Co-authored-by: Strobel Maximilian (IFAG PSS SIS SCE ACM) <Maximilian.Strobel@infineon.com>
This commit is contained in:
Max Strobel 2024-06-04 11:11:36 +02:00 committed by GitHub
parent e83cf58145
commit 36ade4a32b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 1 additions and 5 deletions

View File

@ -67,8 +67,6 @@ class PatchTSTConfig(PretrainedConfig):
A value added to the denominator for numerical stability of normalization.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout probability for the attention probabilities.
dropout (`float`, *optional*, defaults to 0.0):
The dropout probability for all fully connected layers in the Transformer.
positional_dropout (`float`, *optional*, defaults to 0.0):
The dropout probability in the positional embedding layer.
path_dropout (`float`, *optional*, defaults to 0.0):
@ -167,7 +165,6 @@ class PatchTSTConfig(PretrainedConfig):
norm_type: str = "batchnorm",
norm_eps: float = 1e-05,
attention_dropout: float = 0.0,
dropout: float = 0.0,
positional_dropout: float = 0.0,
path_dropout: float = 0.0,
ff_dropout: float = 0.0,
@ -209,7 +206,6 @@ class PatchTSTConfig(PretrainedConfig):
self.num_attention_heads = num_attention_heads
self.ffn_dim = ffn_dim
self.num_hidden_layers = num_hidden_layers
self.dropout = dropout
self.attention_dropout = attention_dropout
self.share_embedding = share_embedding
self.channel_attention = channel_attention

View File

@ -1262,7 +1262,7 @@ class PatchTSTMaskPretrainHead(nn.Module):
def __init__(self, config: PatchTSTConfig):
super().__init__()
self.dropout = nn.Dropout(config.dropout)
self.dropout = nn.Dropout(config.head_dropout) if config.head_dropout > 0 else nn.Identity()
self.linear = nn.Linear(config.d_model, config.patch_length)
self.use_cls_token = config.use_cls_token