Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
🚨 🚨 🚨 Fix ViT parameter initialization (#19341)
This PR aims to rectify the discrepancy between the training performance of the HF and timm ViT implementations.
- Initializes torch and flax ViT dense layer weights with trunc_normal instead of normal (consistent with the TF implementation).
- Initializes cls_token and position_embeddings with trunc_normal.
- Updates the DeiT copy to reflect the changes.
parent 7e7f62bfa7
commit f0b490151e
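For reference, a minimal standalone sketch of the dense-weight change described above (not part of the diff; the layer shape and `initializer_range=0.02` are illustrative values):

```python
import torch
from torch import nn

initializer_range = 0.02  # illustrative; the usual ViT default
linear = nn.Linear(768, 768)  # stand-in for any nn.Linear / nn.Conv2d module

# old behaviour: plain normal initialization
# linear.weight.data.normal_(mean=0.0, std=initializer_range)

# new behaviour: truncated normal, as in the updated _init_weights
linear.weight.data = nn.init.trunc_normal_(linear.weight.data, mean=0.0, std=initializer_range)
if linear.bias is not None:
    linear.bias.data.zero_()
```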
@@ -402,9 +402,7 @@ class DeiTPreTrainedModel(PreTrainedModel):
     def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
         """Initialize the weights"""
         if isinstance(module, (nn.Linear, nn.Conv2d)):
-            # Slightly different from the TF version which uses truncated_normal for initialization
-            # cf https://github.com/pytorch/pytorch/pull/5617
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            module.weight.data = nn.init.trunc_normal_(module.weight.data, mean=0.0, std=self.config.initializer_range)
             if module.bias is not None:
                 module.bias.data.zero_()
         elif isinstance(module, nn.LayerNorm):
@@ -101,7 +101,9 @@ class FlaxViTPatchEmbeddings(nn.Module):
             strides=(patch_size, patch_size),
             padding="VALID",
             dtype=self.dtype,
-            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+            kernel_init=jax.nn.initializers.variance_scaling(
+                self.config.initializer_range**2, "fan_in", "truncated_normal"
+            ),
         )

     def __call__(self, pixel_values):
@@ -122,11 +124,17 @@ class FlaxViTEmbeddings(nn.Module):
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation

     def setup(self):
-        self.cls_token = self.param("cls_token", nn.initializers.zeros, (1, 1, self.config.hidden_size))
+        self.cls_token = self.param(
+            "cls_token",
+            jax.nn.initializers.variance_scaling(self.config.initializer_range**2, "fan_in", "truncated_normal"),
+            (1, 1, self.config.hidden_size),
+        )
         self.patch_embeddings = FlaxViTPatchEmbeddings(self.config, dtype=self.dtype)
         num_patches = self.patch_embeddings.num_patches
         self.position_embeddings = self.param(
-            "position_embeddings", nn.initializers.zeros, (1, num_patches + 1, self.config.hidden_size)
+            "position_embeddings",
+            jax.nn.initializers.variance_scaling(self.config.initializer_range**2, "fan_in", "truncated_normal"),
+            (1, num_patches + 1, self.config.hidden_size),
         )
         self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

@@ -156,19 +164,25 @@ class FlaxViTSelfAttention(nn.Module):
         self.query = nn.Dense(
             self.config.hidden_size,
             dtype=self.dtype,
-            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+            kernel_init=jax.nn.initializers.variance_scaling(
+                self.config.initializer_range**2, mode="fan_in", distribution="truncated_normal"
+            ),
             use_bias=self.config.qkv_bias,
         )
         self.key = nn.Dense(
             self.config.hidden_size,
             dtype=self.dtype,
-            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+            kernel_init=jax.nn.initializers.variance_scaling(
+                self.config.initializer_range**2, mode="fan_in", distribution="truncated_normal"
+            ),
             use_bias=self.config.qkv_bias,
         )
         self.value = nn.Dense(
             self.config.hidden_size,
             dtype=self.dtype,
-            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+            kernel_init=jax.nn.initializers.variance_scaling(
+                self.config.initializer_range**2, mode="fan_in", distribution="truncated_normal"
+            ),
             use_bias=self.config.qkv_bias,
         )

@@ -214,7 +228,9 @@ class FlaxViTSelfOutput(nn.Module):
     def setup(self):
         self.dense = nn.Dense(
             self.config.hidden_size,
-            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+            kernel_init=jax.nn.initializers.variance_scaling(
+                self.config.initializer_range**2, "fan_in", "truncated_normal"
+            ),
             dtype=self.dtype,
         )
         self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
@@ -253,7 +269,9 @@ class FlaxViTIntermediate(nn.Module):
    def setup(self):
        self.dense = nn.Dense(
            self.config.intermediate_size,
-            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+            kernel_init=jax.nn.initializers.variance_scaling(
+                self.config.initializer_range**2, "fan_in", "truncated_normal"
+            ),
            dtype=self.dtype,
        )
        self.activation = ACT2FN[self.config.hidden_act]
@@ -271,7 +289,9 @@ class FlaxViTOutput(nn.Module):
     def setup(self):
         self.dense = nn.Dense(
             self.config.hidden_size,
-            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+            kernel_init=jax.nn.initializers.variance_scaling(
+                self.config.initializer_range**2, "fan_in", "truncated_normal"
+            ),
             dtype=self.dtype,
         )
         self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
@@ -394,7 +414,9 @@ class FlaxViTPooler(nn.Module):
     def setup(self):
         self.dense = nn.Dense(
             self.config.hidden_size,
-            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+            kernel_init=jax.nn.initializers.variance_scaling(
+                self.config.initializer_range**2, "fan_in", "truncated_normal"
+            ),
             dtype=self.dtype,
         )

@@ -572,7 +594,9 @@ class FlaxViTForImageClassificationModule(nn.Module):
         self.classifier = nn.Dense(
             self.config.num_labels,
             dtype=self.dtype,
-            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+            kernel_init=jax.nn.initializers.variance_scaling(
+                self.config.initializer_range**2, "fan_in", "truncated_normal"
+            ),
         )

     def __call__(
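A rough sketch of how the new Flax kernel initializer behaves in isolation (not part of the diff; the shape, seed, and `initializer_range=0.02` are illustrative). Note that `variance_scaling` in "fan_in" mode divides the requested variance by the layer's input dimension before sampling from a truncated normal:

```python
import jax
import jax.numpy as jnp

initializer_range = 0.02  # illustrative value

# same initializer call as in the updated Flax ViT modules
kernel_init = jax.nn.initializers.variance_scaling(
    initializer_range**2, "fan_in", "truncated_normal"
)

key = jax.random.PRNGKey(0)
kernel = kernel_init(key, (768, 768), jnp.float32)  # (fan_in, fan_out), as nn.Dense would request
print(kernel.shape, float(kernel.std()))
```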
@@ -69,11 +69,14 @@ class TFViTEmbeddings(tf.keras.layers.Layer):

         num_patches = self.patch_embeddings.num_patches
         self.cls_token = self.add_weight(
-            shape=(1, 1, self.config.hidden_size), initializer="zeros", trainable=True, name="cls_token"
+            shape=(1, 1, self.config.hidden_size),
+            initializer=get_initializer(self.config.initializer_range),
+            trainable=True,
+            name="cls_token",
         )
         self.position_embeddings = self.add_weight(
             shape=(1, num_patches + 1, self.config.hidden_size),
-            initializer="zeros",
+            initializer=get_initializer(self.config.initializer_range),
             trainable=True,
             name="position_embeddings",
         )
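On the TF side, `get_initializer` comes from `transformers.modeling_tf_utils` and wraps a truncated-normal Keras initializer. A hedged sketch, with an illustrative shape, of sampling it outside the model:

```python
import tensorflow as tf
from transformers.modeling_tf_utils import get_initializer

initializer = get_initializer(0.02)  # truncated normal with stddev=0.02
cls_token_init = initializer(shape=(1, 1, 768), dtype=tf.float32)
print(cls_token_init.shape)
```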
@@ -67,11 +67,17 @@ class ViTEmbeddings(nn.Module):
     def __init__(self, config: ViTConfig, use_mask_token: bool = False) -> None:
         super().__init__()

-        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
+        self.cls_token = nn.Parameter(
+            nn.init.trunc_normal_(torch.zeros(1, 1, config.hidden_size), mean=0.0, std=config.initializer_range)
+        )
         self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
         self.patch_embeddings = ViTPatchEmbeddings(config)
         num_patches = self.patch_embeddings.num_patches
-        self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
+        self.position_embeddings = nn.Parameter(
+            nn.init.trunc_normal_(
+                torch.zeros(1, num_patches + 1, config.hidden_size), mean=0.0, std=config.initializer_range
+            )
+        )
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.config = config

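The same embedding-parameter change pulled out of `ViTEmbeddings` for illustration (hidden size, patch count, and `initializer_range` are placeholder values):

```python
import torch
from torch import nn

hidden_size, num_patches, initializer_range = 768, 196, 0.02  # placeholder values

cls_token = nn.Parameter(
    nn.init.trunc_normal_(torch.zeros(1, 1, hidden_size), mean=0.0, std=initializer_range)
)
position_embeddings = nn.Parameter(
    nn.init.trunc_normal_(torch.zeros(1, num_patches + 1, hidden_size), mean=0.0, std=initializer_range)
)
print(cls_token.shape, position_embeddings.shape)
```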
@@ -440,9 +446,7 @@ class ViTPreTrainedModel(PreTrainedModel):
     def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
         """Initialize the weights"""
         if isinstance(module, (nn.Linear, nn.Conv2d)):
-            # Slightly different from the TF version which uses truncated_normal for initialization
-            # cf https://github.com/pytorch/pytorch/pull/5617
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            module.weight.data = nn.init.trunc_normal_(module.weight.data, mean=0.0, std=self.config.initializer_range)
             if module.bias is not None:
                 module.bias.data.zero_()
         elif isinstance(module, nn.LayerNorm):
@@ -581,7 +581,6 @@ class ViTMAEPreTrainedModel(PreTrainedModel):
     main_input_name = "pixel_values"
     supports_gradient_checkpointing = True

-    # Copied from transformers.models.vit.modeling_vit.ViTPreTrainedModel._init_weights
     def _init_weights(self, module):
         """Initialize the weights"""
         if isinstance(module, (nn.Linear, nn.Conv2d)):