Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
[Vision] [Refactor] Initialize weights on the correct place (#20803)

* fix nit
  - initialization on `_init_weights`
  - fix copies
* add copied from

This commit is contained in:
parent 6b5a8f83ce
commit ecd7de3dff
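The diff below moves the truncated-normal initialization of the CLS token and position embeddings out of the embedding modules' `__init__` and into the `_init_weights` hook of the corresponding `*PreTrainedModel` classes. As a rough sketch of the pattern (illustrative toy code, not the actual transformers implementation), the parameters are now only allocated in `__init__` and filled in by `_init_weights`:

import torch
import torch.nn as nn


class ToyEmbeddings(nn.Module):
    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        # Before this commit: nn.init.trunc_normal_ was called right here in __init__.
        # After: only a placeholder is allocated; _init_weights supplies the real values.
        self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_size))


class ToyPreTrainedModel(nn.Module):
    initializer_range = 0.02  # stands in for config.initializer_range

    def _init_weights(self, module: nn.Module) -> None:
        # In transformers this hook is invoked for every submodule via self.apply(...)
        if isinstance(module, ToyEmbeddings):
            nn.init.trunc_normal_(module.cls_token, mean=0.0, std=self.initializer_range)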
@@ -382,7 +382,7 @@ class ASTPreTrainedModel(PreTrainedModel):
     main_input_name = "input_values"
     supports_gradient_checkpointing = True

-    # Copied from transformers.models.vit.modeling_vit.ViTPreTrainedModel._init_weights
+    # Copied from transformers.models.deit.modeling_deit.DeiTPreTrainedModel._init_weights
     def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
         """Initialize the weights"""
         if isinstance(module, (nn.Linear, nn.Conv2d)):
@@ -387,7 +387,6 @@ class DeiTEncoder(nn.Module):
         )


-# Copied from transformers.models.vit.modeling_vit.ViTPreTrainedModel with ViT->DeiT all-casing
 class DeiTPreTrainedModel(PreTrainedModel):
     """
     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@@ -67,21 +67,11 @@ class ViTEmbeddings(nn.Module):
     def __init__(self, config: ViTConfig, use_mask_token: bool = False) -> None:
         super().__init__()

-        self.cls_token = nn.Parameter(
-            nn.init.trunc_normal_(
-                torch.zeros(1, 1, config.hidden_size, dtype=torch.float32), mean=0.0, std=config.initializer_range
-            )
-        )
+        self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
         self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
         self.patch_embeddings = ViTPatchEmbeddings(config)
         num_patches = self.patch_embeddings.num_patches
-        self.position_embeddings = nn.Parameter(
-            nn.init.trunc_normal_(
-                torch.zeros(1, num_patches + 1, config.hidden_size, dtype=torch.float32),
-                mean=0.0,
-                std=config.initializer_range,
-            )
-        )
+        self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.config = config
@@ -461,6 +451,18 @@ class ViTPreTrainedModel(PreTrainedModel):
         elif isinstance(module, nn.LayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
+        elif isinstance(module, ViTEmbeddings):
+            nn.init.trunc_normal_(
+                module.position_embeddings,
+                mean=0.0,
+                std=self.config.initializer_range,
+            )
+
+            nn.init.trunc_normal_(
+                module.cls_token,
+                mean=0.0,
+                std=self.config.initializer_range,
+            )

     def _set_gradient_checkpointing(self, module: ViTEncoder, value: bool = False) -> None:
         if isinstance(module, ViTEncoder):
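With the initialization moved into `_init_weights`, the `torch.randn` placeholders created in `__init__` are only temporary: when a model is built from a config, the standard `PreTrainedModel` flow (`post_init()` eventually applying `self._init_weights` to every submodule) re-initializes them, and when loading a checkpoint the pretrained values take precedence. A minimal usage sketch, assuming the usual transformers API (the checkpoint name is only an example):

from transformers import ViTConfig, ViTModel

# Built from a config: post_init() ends up applying _init_weights to every submodule,
# so cls_token / position_embeddings get the truncated-normal init shown above.
model = ViTModel(ViTConfig())

# Loaded from a checkpoint: the stored weights take precedence over any init.
# model = ViTModel.from_pretrained("google/vit-base-patch16-224")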
@@ -59,24 +59,15 @@ class ViTHybridEmbeddings(nn.Module):
     Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
     """

+    # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.__init__ with ViT->ViTHybrid
     def __init__(self, config: ViTHybridConfig, use_mask_token: bool = False) -> None:
         super().__init__()

-        self.cls_token = nn.Parameter(
-            nn.init.trunc_normal_(
-                torch.zeros(1, 1, config.hidden_size, dtype=torch.float32), mean=0.0, std=config.initializer_range
-            )
-        )
+        self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
         self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
         self.patch_embeddings = ViTHybridPatchEmbeddings(config)
         num_patches = self.patch_embeddings.num_patches
-        self.position_embeddings = nn.Parameter(
-            nn.init.trunc_normal_(
-                torch.zeros(1, num_patches + 1, config.hidden_size, dtype=torch.float32),
-                mean=0.0,
-                std=config.initializer_range,
-            )
-        )
+        self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.config = config
@@ -485,6 +476,18 @@ class ViTHybridPreTrainedModel(PreTrainedModel):
         elif isinstance(module, nn.LayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
+        elif isinstance(module, ViTHybridEmbeddings):
+            nn.init.trunc_normal_(
+                module.position_embeddings,
+                mean=0.0,
+                std=self.config.initializer_range,
+            )
+
+            nn.init.trunc_normal_(
+                module.cls_token,
+                mean=0.0,
+                std=self.config.initializer_range,
+            )

     def _set_gradient_checkpointing(self, module: ViTHybridEncoder, value: bool = False) -> None:
         if isinstance(module, ViTHybridEncoder):
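A quick way to convince yourself that the placeholders really are overwritten (a sanity-check sketch, not part of the commit): the spread of the freshly constructed embeddings should track `config.initializer_range` rather than the unit variance of `torch.randn`.

from transformers import ViTConfig, ViTModel

config = ViTConfig(initializer_range=0.02)
model = ViTModel(config)

cls_std = model.embeddings.cls_token.data.std().item()
pos_std = model.embeddings.position_embeddings.data.std().item()
print(f"cls_token std: {cls_std:.4f}, position_embeddings std: {pos_std:.4f}")
# Both should come out close to 0.02 rather than the ~1.0 of a raw torch.randn tensor.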