diff --git a/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py b/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py
index 6daec258b6e..9cf24f31e5f 100644
--- a/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py
+++ b/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py
@@ -382,7 +382,7 @@ class ASTPreTrainedModel(PreTrainedModel):
     main_input_name = "input_values"
     supports_gradient_checkpointing = True
 
-    # Copied from transformers.models.vit.modeling_vit.ViTPreTrainedModel._init_weights
+    # Copied from transformers.models.deit.modeling_deit.DeiTPreTrainedModel._init_weights
     def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
         """Initialize the weights"""
         if isinstance(module, (nn.Linear, nn.Conv2d)):
diff --git a/src/transformers/models/deit/modeling_deit.py b/src/transformers/models/deit/modeling_deit.py
index 646b6b9a2a9..176ba012448 100644
--- a/src/transformers/models/deit/modeling_deit.py
+++ b/src/transformers/models/deit/modeling_deit.py
@@ -387,7 +387,6 @@ class DeiTEncoder(nn.Module):
         )
 
 
-# Copied from transformers.models.vit.modeling_vit.ViTPreTrainedModel with ViT->DeiT all-casing
 class DeiTPreTrainedModel(PreTrainedModel):
     """
     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
diff --git a/src/transformers/models/vit/modeling_vit.py b/src/transformers/models/vit/modeling_vit.py
index 1b937750c0e..5cf09889ca8 100644
--- a/src/transformers/models/vit/modeling_vit.py
+++ b/src/transformers/models/vit/modeling_vit.py
@@ -67,21 +67,11 @@ class ViTEmbeddings(nn.Module):
     def __init__(self, config: ViTConfig, use_mask_token: bool = False) -> None:
         super().__init__()
 
-        self.cls_token = nn.Parameter(
-            nn.init.trunc_normal_(
-                torch.zeros(1, 1, config.hidden_size, dtype=torch.float32), mean=0.0, std=config.initializer_range
-            )
-        )
+        self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
         self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
         self.patch_embeddings = ViTPatchEmbeddings(config)
         num_patches = self.patch_embeddings.num_patches
-        self.position_embeddings = nn.Parameter(
-            nn.init.trunc_normal_(
-                torch.zeros(1, num_patches + 1, config.hidden_size, dtype=torch.float32),
-                mean=0.0,
-                std=config.initializer_range,
-            )
-        )
+        self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.config = config
 
@@ -461,6 +451,18 @@ class ViTPreTrainedModel(PreTrainedModel):
         elif isinstance(module, nn.LayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
+        elif isinstance(module, ViTEmbeddings):
+            nn.init.trunc_normal_(
+                module.position_embeddings,
+                mean=0.0,
+                std=self.config.initializer_range,
+            )
+
+            nn.init.trunc_normal_(
+                module.cls_token,
+                mean=0.0,
+                std=self.config.initializer_range,
+            )
 
     def _set_gradient_checkpointing(self, module: ViTEncoder, value: bool = False) -> None:
         if isinstance(module, ViTEncoder):
diff --git a/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py b/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py
index 8aa0019c72a..8517f0f95bf 100644
--- a/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py
+++ b/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py
@@ -59,24 +59,15 @@ class ViTHybridEmbeddings(nn.Module):
     Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
     """
 
+    # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.__init__ with ViT->ViTHybrid
     def __init__(self, config: ViTHybridConfig, use_mask_token: bool = False) -> None:
         super().__init__()
 
-        self.cls_token = nn.Parameter(
-            nn.init.trunc_normal_(
-                torch.zeros(1, 1, config.hidden_size, dtype=torch.float32), mean=0.0, std=config.initializer_range
-            )
-        )
+        self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
         self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
         self.patch_embeddings = ViTHybridPatchEmbeddings(config)
         num_patches = self.patch_embeddings.num_patches
-        self.position_embeddings = nn.Parameter(
-            nn.init.trunc_normal_(
-                torch.zeros(1, num_patches + 1, config.hidden_size, dtype=torch.float32),
-                mean=0.0,
-                std=config.initializer_range,
-            )
-        )
+        self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.config = config
 
@@ -485,6 +476,18 @@ class ViTHybridPreTrainedModel(PreTrainedModel):
         elif isinstance(module, nn.LayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
+        elif isinstance(module, ViTHybridEmbeddings):
+            nn.init.trunc_normal_(
+                module.position_embeddings,
+                mean=0.0,
+                std=self.config.initializer_range,
+            )
+
+            nn.init.trunc_normal_(
+                module.cls_token,
+                mean=0.0,
+                std=self.config.initializer_range,
+            )
 
     def _set_gradient_checkpointing(self, module: ViTHybridEncoder, value: bool = False) -> None:
         if isinstance(module, ViTHybridEncoder):
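For context, a minimal self-contained sketch of the pattern this diff adopts. ToyEmbeddings and ToyModel are hypothetical stand-ins, not transformers classes: parameters get cheap torch.randn placeholders in __init__, and the actual truncated-normal initialization happens once in _init_weights, which is applied recursively over all submodules (mimicked here with self.apply; in transformers this is driven by PreTrainedModel.post_init()).

import torch
import torch.nn as nn


class ToyEmbeddings(nn.Module):
    """Hypothetical stand-in for ViTEmbeddings."""

    def __init__(self, hidden_size: int = 768, num_patches: int = 196) -> None:
        super().__init__()
        # Placeholder values only; they are overwritten by _init_weights below
        # (or by checkpoint weights when loading a pretrained model).
        self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_size))
        self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, hidden_size))


class ToyModel(nn.Module):
    """Hypothetical stand-in for ViTPreTrainedModel."""

    def __init__(self, initializer_range: float = 0.02) -> None:
        super().__init__()
        self.initializer_range = initializer_range
        self.embeddings = ToyEmbeddings()
        # Apply _init_weights to self and every submodule, as post_init() does.
        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module) -> None:
        if isinstance(module, ToyEmbeddings):
            # In-place truncated normal init, matching the new elif branch above.
            nn.init.trunc_normal_(module.position_embeddings, mean=0.0, std=self.initializer_range)
            nn.init.trunc_normal_(module.cls_token, mean=0.0, std=self.initializer_range)


model = ToyModel()
print(model.embeddings.cls_token.std().item())  # roughly initializer_range (0.02)

The design point of the diff: when loading a pretrained checkpoint, the checkpoint weights replace these tensors anyway, so doing the expensive trunc_normal_ eagerly in __init__ is wasted work; centralizing it in _init_weights also keeps cls_token and position_embeddings initialized through the same code path as every other weight in the model.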