Fix momentum and epsilon values (#19454)

The momentum value for PyTorch and TensorFlow batch normalization layers is not equivalent. The TensorFlow value should be (1 - pytorch_momentum) in order to ensure the correct updates are applied to the running mean and running variance calculations. We wouldn't observe a difference loading a pretrained model and performing inference, but evaluation outputs would change after some training steps.
This commit is contained in:
amyeroberts 2022-10-10 15:17:41 +01:00 committed by GitHub
parent b0b962ccca
commit 4dd784c32f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 9 additions and 9 deletions

View File

@ -1041,7 +1041,7 @@ class TFData2VecVisionConvModule(tf.keras.layers.Layer):
dilation_rate=dilation,
name="conv",
)
self.bn = tf.keras.layers.BatchNormalization(name="bn")
self.bn = tf.keras.layers.BatchNormalization(name="bn", momentum=0.9, epsilon=1e-5)
self.activation = tf.nn.relu
def call(self, input: tf.Tensor) -> tf.Tensor:
@ -1331,7 +1331,7 @@ class TFData2VecVisionForSemanticSegmentation(TFData2VecVisionPreTrainedModel):
# FPNs
self.fpn1 = [
tf.keras.layers.Conv2DTranspose(config.hidden_size, kernel_size=2, strides=2, name="fpn1.0"),
tf.keras.layers.BatchNormalization(name="fpn1.1"),
tf.keras.layers.BatchNormalization(name="fpn1.1", momentum=0.9, epsilon=1e-5),
tf.keras.layers.Activation("gelu"),
tf.keras.layers.Conv2DTranspose(config.hidden_size, kernel_size=2, strides=2, name="fpn1.3"),
]

View File

@ -1253,13 +1253,13 @@ class TFGroupViTMainLayer(tf.keras.layers.Layer):
self.visual_projection = [
tf.keras.layers.Dense(self.projection_intermediate_dim, name="visual_projection.0"),
tf.keras.layers.BatchNormalization(name="visual_projection.1", momentum=0.1, epsilon=1e-5),
tf.keras.layers.BatchNormalization(name="visual_projection.1", momentum=0.9, epsilon=1e-5),
tf.keras.layers.ReLU(name="visual_projection.2"),
tf.keras.layers.Dense(self.projection_dim, name="visual_projection.3"),
]
self.text_projection = [
tf.keras.layers.Dense(self.projection_intermediate_dim, name="text_projection.0"),
tf.keras.layers.BatchNormalization(name="text_projection.1", momentum=0.1, epsilon=1e-5),
tf.keras.layers.BatchNormalization(name="text_projection.1", momentum=0.9, epsilon=1e-5),
tf.keras.layers.ReLU(name="text_projection.2"),
tf.keras.layers.Dense(self.projection_dim, name="text_projection.3"),
]

View File

@ -74,7 +74,7 @@ class TFRegNetConvLayer(tf.keras.layers.Layer):
use_bias=False,
name="convolution",
)
self.normalization = tf.keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.1, name="normalization")
self.normalization = tf.keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.9, name="normalization")
self.activation = ACT2FN[activation] if activation is not None else tf.identity
def call(self, hidden_state):
@ -126,7 +126,7 @@ class TFRegNetShortCut(tf.keras.layers.Layer):
self.convolution = tf.keras.layers.Conv2D(
filters=out_channels, kernel_size=1, strides=stride, use_bias=False, name="convolution"
)
self.normalization = tf.keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.1, name="normalization")
self.normalization = tf.keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.9, name="normalization")
def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor:
return self.normalization(self.convolution(inputs), training=training)

View File

@ -60,7 +60,7 @@ class TFResNetConvLayer(tf.keras.layers.Layer):
out_channels, kernel_size=kernel_size, strides=stride, padding="valid", use_bias=False, name="convolution"
)
# Use same default momentum and epsilon as PyTorch equivalent
self.normalization = tf.keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.1, name="normalization")
self.normalization = tf.keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.9, name="normalization")
self.activation = ACT2FN[activation] if activation is not None else tf.keras.layers.Activation("linear")
def convolution(self, hidden_state: tf.Tensor) -> tf.Tensor:
@ -119,7 +119,7 @@ class TFResNetShortCut(tf.keras.layers.Layer):
out_channels, kernel_size=1, strides=stride, use_bias=False, name="convolution"
)
# Use same default momentum and epsilon as PyTorch equivalent
self.normalization = tf.keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.1, name="normalization")
self.normalization = tf.keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.9, name="normalization")
def call(self, x: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_state = x

View File

@ -741,7 +741,7 @@ class TFSegformerDecodeHead(TFSegformerPreTrainedModel):
self.linear_fuse = tf.keras.layers.Conv2D(
filters=config.decoder_hidden_size, kernel_size=1, use_bias=False, name="linear_fuse"
)
self.batch_norm = tf.keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.1, name="batch_norm")
self.batch_norm = tf.keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.9, name="batch_norm")
self.activation = tf.keras.layers.Activation("relu")
self.dropout = tf.keras.layers.Dropout(config.classifier_dropout_prob)