mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
fix(PatchTST): Wrong dropout used for PretainHead (#31117)
* fix(PatchTST): Wrong dropout used for PretainHead * feat(PatchTST): remove unused config.dropout --------- Co-authored-by: Strobel Maximilian (IFAG PSS SIS SCE ACM) <Maximilian.Strobel@infineon.com>
This commit is contained in:
parent
e83cf58145
commit
36ade4a32b
@ -67,8 +67,6 @@ class PatchTSTConfig(PretrainedConfig):
|
||||
A value added to the denominator for numerical stability of normalization.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout probability for the attention probabilities.
|
||||
dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout probability for all fully connected layers in the Transformer.
|
||||
positional_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout probability in the positional embedding layer.
|
||||
path_dropout (`float`, *optional*, defaults to 0.0):
|
||||
@ -167,7 +165,6 @@ class PatchTSTConfig(PretrainedConfig):
|
||||
norm_type: str = "batchnorm",
|
||||
norm_eps: float = 1e-05,
|
||||
attention_dropout: float = 0.0,
|
||||
dropout: float = 0.0,
|
||||
positional_dropout: float = 0.0,
|
||||
path_dropout: float = 0.0,
|
||||
ff_dropout: float = 0.0,
|
||||
@ -209,7 +206,6 @@ class PatchTSTConfig(PretrainedConfig):
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.ffn_dim = ffn_dim
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.dropout = dropout
|
||||
self.attention_dropout = attention_dropout
|
||||
self.share_embedding = share_embedding
|
||||
self.channel_attention = channel_attention
|
||||
|
@ -1262,7 +1262,7 @@ class PatchTSTMaskPretrainHead(nn.Module):
|
||||
|
||||
def __init__(self, config: PatchTSTConfig):
|
||||
super().__init__()
|
||||
self.dropout = nn.Dropout(config.dropout)
|
||||
self.dropout = nn.Dropout(config.head_dropout) if config.head_dropout > 0 else nn.Identity()
|
||||
self.linear = nn.Linear(config.d_model, config.patch_length)
|
||||
self.use_cls_token = config.use_cls_token
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user