[Vision] [Refactor] Initialize weights on the correct place (#20803)

* fix nit

- initialization in `_init_weights`
- fix copies

* add copied from
Younes Belkada 2022-12-19 10:37:14 +01:00 committed by GitHub
parent 6b5a8f83ce
commit ecd7de3dff
4 changed files with 30 additions and 26 deletions
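The core of the refactor: `cls_token` and `position_embeddings` are now created with plain `torch.randn` in `__init__`, and the truncated-normal initialization moves into `_init_weights`, the hook that `PreTrainedModel.post_init()` and `from_pretrained` drive. A minimal sketch of the pattern in plain PyTorch (the `ToyEmbeddings`/`ToyModel` names are hypothetical stand-ins, not the actual transformers classes):

```python
import torch
import torch.nn as nn


class ToyEmbeddings(nn.Module):
    def __init__(self, hidden_size: int, num_patches: int) -> None:
        super().__init__()
        # Parameters are created with placeholder values only; the real
        # initialization is deferred to _init_weights below.
        self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_size))
        self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, hidden_size))


class ToyModel(nn.Module):
    def __init__(self, hidden_size: int = 32, num_patches: int = 16, initializer_range: float = 0.02) -> None:
        super().__init__()
        self.initializer_range = initializer_range
        self.embeddings = ToyEmbeddings(hidden_size, num_patches)
        # Mimics PreTrainedModel.post_init(): walk all submodules and initialize them.
        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module) -> None:
        if isinstance(module, ToyEmbeddings):
            nn.init.trunc_normal_(module.position_embeddings, mean=0.0, std=self.initializer_range)
            nn.init.trunc_normal_(module.cls_token, mean=0.0, std=self.initializer_range)


model = ToyModel()
print(float(model.embeddings.cls_token.std()))  # ≈ 0.02
```

Keeping the eager `trunc_normal_` out of `__init__` matters because, with the default fast-initialization path, `from_pretrained` only runs `_init_weights` for weights missing from the checkpoint; initialization done at construction time is wasted work for loaded weights.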

src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py

@@ -382,7 +382,7 @@ class ASTPreTrainedModel(PreTrainedModel):
     main_input_name = "input_values"
     supports_gradient_checkpointing = True
-    # Copied from transformers.models.vit.modeling_vit.ViTPreTrainedModel._init_weights
+    # Copied from transformers.models.deit.modeling_deit.DeiTPreTrainedModel._init_weights
     def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
         """Initialize the weights"""
         if isinstance(module, (nn.Linear, nn.Conv2d)):

src/transformers/models/deit/modeling_deit.py

@@ -387,7 +387,6 @@ class DeiTEncoder(nn.Module):
         )
-# Copied from transformers.models.vit.modeling_vit.ViTPreTrainedModel with ViT->DeiT all-casing
 class DeiTPreTrainedModel(PreTrainedModel):
     """
     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained

src/transformers/models/vit/modeling_vit.py

@@ -67,21 +67,11 @@ class ViTEmbeddings(nn.Module):
     def __init__(self, config: ViTConfig, use_mask_token: bool = False) -> None:
         super().__init__()
-        self.cls_token = nn.Parameter(
-            nn.init.trunc_normal_(
-                torch.zeros(1, 1, config.hidden_size, dtype=torch.float32), mean=0.0, std=config.initializer_range
-            )
-        )
+        self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
         self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
         self.patch_embeddings = ViTPatchEmbeddings(config)
         num_patches = self.patch_embeddings.num_patches
-        self.position_embeddings = nn.Parameter(
-            nn.init.trunc_normal_(
-                torch.zeros(1, num_patches + 1, config.hidden_size, dtype=torch.float32),
-                mean=0.0,
-                std=config.initializer_range,
-            )
-        )
+        self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.config = config
@@ -461,6 +451,18 @@ class ViTPreTrainedModel(PreTrainedModel):
         elif isinstance(module, nn.LayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
+        elif isinstance(module, ViTEmbeddings):
+            nn.init.trunc_normal_(
+                module.position_embeddings,
+                mean=0.0,
+                std=self.config.initializer_range,
+            )
+            nn.init.trunc_normal_(
+                module.cls_token,
+                mean=0.0,
+                std=self.config.initializer_range,
+            )
     def _set_gradient_checkpointing(self, module: ViTEncoder, value: bool = False) -> None:
         if isinstance(module, ViTEncoder):
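Worth noting: `nn.init.trunc_normal_` mutates its argument in place (under `torch.no_grad()`) and returns the same tensor, which is why it can be applied directly to the already-constructed `nn.Parameter` here instead of wrapping a freshly initialized tensor as before. A standalone illustration:

```python
import torch
import torch.nn as nn

param = nn.Parameter(torch.randn(1, 1, 768))
result = nn.init.trunc_normal_(param, mean=0.0, std=0.02)

assert result is param      # in-place: the same object comes back
assert param.requires_grad  # still a leaf parameter tracking gradients
print(float(param.std()))   # ≈ 0.02 after re-initialization
```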

src/transformers/models/vit_hybrid/modeling_vit_hybrid.py

@@ -59,24 +59,15 @@ class ViTHybridEmbeddings(nn.Module):
     Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
     """
+    # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.__init__ with ViT->ViTHybrid
     def __init__(self, config: ViTHybridConfig, use_mask_token: bool = False) -> None:
         super().__init__()
-        self.cls_token = nn.Parameter(
-            nn.init.trunc_normal_(
-                torch.zeros(1, 1, config.hidden_size, dtype=torch.float32), mean=0.0, std=config.initializer_range
-            )
-        )
+        self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
         self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
         self.patch_embeddings = ViTHybridPatchEmbeddings(config)
         num_patches = self.patch_embeddings.num_patches
-        self.position_embeddings = nn.Parameter(
-            nn.init.trunc_normal_(
-                torch.zeros(1, num_patches + 1, config.hidden_size, dtype=torch.float32),
-                mean=0.0,
-                std=config.initializer_range,
-            )
-        )
+        self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.config = config
@@ -485,6 +476,18 @@ class ViTHybridPreTrainedModel(PreTrainedModel):
         elif isinstance(module, nn.LayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
+        elif isinstance(module, ViTHybridEmbeddings):
+            nn.init.trunc_normal_(
+                module.position_embeddings,
+                mean=0.0,
+                std=self.config.initializer_range,
+            )
+            nn.init.trunc_normal_(
+                module.cls_token,
+                mean=0.0,
+                std=self.config.initializer_range,
+            )
     def _set_gradient_checkpointing(self, module: ViTHybridEncoder, value: bool = False) -> None:
         if isinstance(module, ViTHybridEncoder):
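As a quick sanity check (assuming a transformers version that includes this change), a freshly constructed model should pick up the truncated-normal statistics through `_init_weights` rather than at parameter construction:

```python
from transformers import ViTConfig, ViTModel

config = ViTConfig()  # initializer_range defaults to 0.02
model = ViTModel(config)

# post_init() routes ViTEmbeddings through _init_weights, so both
# parameters end up with the truncated-normal statistics.
print(float(model.embeddings.cls_token.std()))            # ≈ 0.02
print(float(model.embeddings.position_embeddings.std()))  # ≈ 0.02
```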