🚨 🚨 🚨 Fix ViT parameter initialization (#19341)

This PR rectifies the discrepancy in training performance between the HF and timm ViT implementations.

- Initializes the PyTorch and Flax ViT dense layer weights with trunc_normal instead of normal (consistent with the TF implementation); a minimal sketch of the change follows this list.
- Initializes cls_token and position_embeddings with trunc_normal
- Updates DeiT copy to reflect the changes
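
As a minimal illustration of the PyTorch-side change (not part of the diff; hidden_size and initializer_range below are assumed example values matching the ViT-Base defaults):

    import torch
    import torch.nn as nn

    hidden_size = 768          # assumed: ViT-Base default
    initializer_range = 0.02   # assumed: ViTConfig default

    # Before: plain normal initialization
    cls_token_old = nn.Parameter(torch.zeros(1, 1, hidden_size))
    cls_token_old.data.normal_(mean=0.0, std=initializer_range)

    # After: truncated normal initialization (torch's default truncation bounds are a=-2.0, b=2.0)
    cls_token_new = nn.Parameter(
        nn.init.trunc_normal_(torch.zeros(1, 1, hidden_size), mean=0.0, std=initializer_range)
    )
    print(cls_token_new.shape)  # torch.Size([1, 1, 768])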
Alara Dirik · 2022-10-06 12:04:01 +03:00 · committed by GitHub
commit f0b490151e · parent 7e7f62bfa7
5 changed files with 50 additions and 22 deletions

src/transformers/models/deit/modeling_deit.py

@@ -402,9 +402,7 @@ class DeiTPreTrainedModel(PreTrainedModel):
     def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
         """Initialize the weights"""
         if isinstance(module, (nn.Linear, nn.Conv2d)):
-            # Slightly different from the TF version which uses truncated_normal for initialization
-            # cf https://github.com/pytorch/pytorch/pull/5617
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            module.weight.data = nn.init.trunc_normal_(module.weight.data, mean=0.0, std=self.config.initializer_range)
             if module.bias is not None:
                 module.bias.data.zero_()
         elif isinstance(module, nn.LayerNorm):

src/transformers/models/vit/modeling_flax_vit.py

@@ -101,7 +101,9 @@ class FlaxViTPatchEmbeddings(nn.Module):
             strides=(patch_size, patch_size),
             padding="VALID",
             dtype=self.dtype,
-            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+            kernel_init=jax.nn.initializers.variance_scaling(
+                self.config.initializer_range**2, "fan_in", "truncated_normal"
+            ),
         )

     def __call__(self, pixel_values):
@@ -122,11 +124,17 @@ class FlaxViTEmbeddings(nn.Module):
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation

     def setup(self):
-        self.cls_token = self.param("cls_token", nn.initializers.zeros, (1, 1, self.config.hidden_size))
+        self.cls_token = self.param(
+            "cls_token",
+            jax.nn.initializers.variance_scaling(self.config.initializer_range**2, "fan_in", "truncated_normal"),
+            (1, 1, self.config.hidden_size),
+        )
         self.patch_embeddings = FlaxViTPatchEmbeddings(self.config, dtype=self.dtype)
         num_patches = self.patch_embeddings.num_patches
         self.position_embeddings = self.param(
-            "position_embeddings", nn.initializers.zeros, (1, num_patches + 1, self.config.hidden_size)
+            "position_embeddings",
+            jax.nn.initializers.variance_scaling(self.config.initializer_range**2, "fan_in", "truncated_normal"),
+            (1, num_patches + 1, self.config.hidden_size),
         )
         self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
@@ -156,19 +164,25 @@ class FlaxViTSelfAttention(nn.Module):
         self.query = nn.Dense(
             self.config.hidden_size,
             dtype=self.dtype,
-            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+            kernel_init=jax.nn.initializers.variance_scaling(
+                self.config.initializer_range**2, mode="fan_in", distribution="truncated_normal"
+            ),
             use_bias=self.config.qkv_bias,
         )
         self.key = nn.Dense(
             self.config.hidden_size,
             dtype=self.dtype,
-            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+            kernel_init=jax.nn.initializers.variance_scaling(
+                self.config.initializer_range**2, mode="fan_in", distribution="truncated_normal"
+            ),
             use_bias=self.config.qkv_bias,
         )
         self.value = nn.Dense(
             self.config.hidden_size,
             dtype=self.dtype,
-            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+            kernel_init=jax.nn.initializers.variance_scaling(
+                self.config.initializer_range**2, mode="fan_in", distribution="truncated_normal"
+            ),
             use_bias=self.config.qkv_bias,
         )
@@ -214,7 +228,9 @@ class FlaxViTSelfOutput(nn.Module):
     def setup(self):
         self.dense = nn.Dense(
             self.config.hidden_size,
-            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+            kernel_init=jax.nn.initializers.variance_scaling(
+                self.config.initializer_range**2, "fan_in", "truncated_normal"
+            ),
             dtype=self.dtype,
         )
         self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
@@ -253,7 +269,9 @@ class FlaxViTIntermediate(nn.Module):
     def setup(self):
         self.dense = nn.Dense(
             self.config.intermediate_size,
-            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+            kernel_init=jax.nn.initializers.variance_scaling(
+                self.config.initializer_range**2, "fan_in", "truncated_normal"
+            ),
             dtype=self.dtype,
         )
         self.activation = ACT2FN[self.config.hidden_act]
@@ -271,7 +289,9 @@ class FlaxViTOutput(nn.Module):
     def setup(self):
         self.dense = nn.Dense(
             self.config.hidden_size,
-            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+            kernel_init=jax.nn.initializers.variance_scaling(
+                self.config.initializer_range**2, "fan_in", "truncated_normal"
+            ),
             dtype=self.dtype,
         )
         self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
@@ -394,7 +414,9 @@ class FlaxViTPooler(nn.Module):
     def setup(self):
         self.dense = nn.Dense(
             self.config.hidden_size,
-            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+            kernel_init=jax.nn.initializers.variance_scaling(
+                self.config.initializer_range**2, "fan_in", "truncated_normal"
+            ),
             dtype=self.dtype,
         )
@@ -572,7 +594,9 @@ class FlaxViTForImageClassificationModule(nn.Module):
         self.classifier = nn.Dense(
             self.config.num_labels,
             dtype=self.dtype,
-            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+            kernel_init=jax.nn.initializers.variance_scaling(
+                self.config.initializer_range**2, "fan_in", "truncated_normal"
+            ),
         )

     def __call__(
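
On the Flax side, the fixed-stddev normal kernel initializer is replaced with jax.nn.initializers.variance_scaling using a truncated-normal distribution. A minimal standalone sketch of what the new initializer samples look like (the kernel shape and initializer_range below are assumed example values, not taken from the diff):

    import jax
    import jax.numpy as jnp

    initializer_range = 0.02   # assumed: ViTConfig default
    kernel_shape = (768, 768)  # assumed (fan_in, fan_out) shape
    key = jax.random.PRNGKey(0)

    # Before: fixed-stddev normal initializer
    old_init = jax.nn.initializers.normal(stddev=initializer_range)
    # After: fan_in-scaled truncated normal, as in the diff above
    new_init = jax.nn.initializers.variance_scaling(
        initializer_range**2, mode="fan_in", distribution="truncated_normal"
    )

    old_kernel = old_init(key, kernel_shape, jnp.float32)
    new_kernel = new_init(key, kernel_shape, jnp.float32)
    print(float(old_kernel.std()), float(new_kernel.std()))

Per JAX's initializer semantics, variance_scaling divides the requested variance by the fan-in, so the sampled standard deviation here is roughly initializer_range / sqrt(fan_in) rather than a fixed initializer_range.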

src/transformers/models/vit/modeling_tf_vit.py

@@ -69,11 +69,14 @@ class TFViTEmbeddings(tf.keras.layers.Layer):
         num_patches = self.patch_embeddings.num_patches
         self.cls_token = self.add_weight(
-            shape=(1, 1, self.config.hidden_size), initializer="zeros", trainable=True, name="cls_token"
+            shape=(1, 1, self.config.hidden_size),
+            initializer=get_initializer(self.config.initializer_range),
+            trainable=True,
+            name="cls_token",
         )
         self.position_embeddings = self.add_weight(
             shape=(1, num_patches + 1, self.config.hidden_size),
-            initializer="zeros",
+            initializer=get_initializer(self.config.initializer_range),
             trainable=True,
             name="position_embeddings",
         )
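
The TF change above relies on get_initializer from transformers, which returns a tf.keras.initializers.TruncatedNormal with the given stddev. A minimal sketch of creating an equivalent weight outside a Keras layer (hidden_size and initializer_range are assumed example values):

    import tensorflow as tf
    from transformers.modeling_tf_utils import get_initializer

    hidden_size = 768          # assumed: ViT-Base default
    initializer_range = 0.02   # assumed: ViTConfig default

    initializer = get_initializer(initializer_range)  # TruncatedNormal(stddev=0.02)
    cls_token = tf.Variable(
        initializer(shape=(1, 1, hidden_size), dtype=tf.float32),
        trainable=True,
        name="cls_token",
    )
    print(cls_token.shape)  # (1, 1, 768)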

src/transformers/models/vit/modeling_vit.py

@@ -67,11 +67,17 @@ class ViTEmbeddings(nn.Module):
     def __init__(self, config: ViTConfig, use_mask_token: bool = False) -> None:
         super().__init__()

-        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
+        self.cls_token = nn.Parameter(
+            nn.init.trunc_normal_(torch.zeros(1, 1, config.hidden_size), mean=0.0, std=config.initializer_range)
+        )
         self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
         self.patch_embeddings = ViTPatchEmbeddings(config)
         num_patches = self.patch_embeddings.num_patches
-        self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
+        self.position_embeddings = nn.Parameter(
+            nn.init.trunc_normal_(
+                torch.zeros(1, num_patches + 1, config.hidden_size), mean=0.0, std=config.initializer_range
+            )
+        )
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.config = config
@@ -440,9 +446,7 @@ class ViTPreTrainedModel(PreTrainedModel):
     def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
         """Initialize the weights"""
         if isinstance(module, (nn.Linear, nn.Conv2d)):
-            # Slightly different from the TF version which uses truncated_normal for initialization
-            # cf https://github.com/pytorch/pytorch/pull/5617
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            module.weight.data = nn.init.trunc_normal_(module.weight.data, mean=0.0, std=self.config.initializer_range)
             if module.bias is not None:
                 module.bias.data.zero_()
         elif isinstance(module, nn.LayerNorm):

src/transformers/models/vit_mae/modeling_vit_mae.py

@@ -581,7 +581,6 @@ class ViTMAEPreTrainedModel(PreTrainedModel):
     main_input_name = "pixel_values"
     supports_gradient_checkpointing = True

-    # Copied from transformers.models.vit.modeling_vit.ViTPreTrainedModel._init_weights
     def _init_weights(self, module):
         """Initialize the weights"""
         if isinstance(module, (nn.Linear, nn.Conv2d)):