From 2c8957feea52e92d31189d428335f8e5f3088453 Mon Sep 17 00:00:00 2001
From: Dan Shirron
Date: Wed, 3 Nov 2021 18:20:48 +0200
Subject: [PATCH] Fix of issue #13327: Wrong weight initialization for TF t5
 model (#14241)

* Fix of issue #13327: Wrong weight initialization for TF t5 model

* run black formatter

* fix typo

* remove my name tag from comments

Co-authored-by: Shirron
---
 src/transformers/models/t5/modeling_tf_t5.py | 75 +++++++++++++++++---
 1 file changed, 64 insertions(+), 11 deletions(-)

diff --git a/src/transformers/models/t5/modeling_tf_t5.py b/src/transformers/models/t5/modeling_tf_t5.py
index cc487d59d21..202e872b008 100644
--- a/src/transformers/models/t5/modeling_tf_t5.py
+++ b/src/transformers/models/t5/modeling_tf_t5.py
@@ -93,8 +93,18 @@ class TFT5LayerNorm(tf.keras.layers.Layer):
 class TFT5DenseReluDense(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
         super().__init__(**kwargs)
-        self.wi = tf.keras.layers.Dense(config.d_ff, use_bias=False, name="wi")
-        self.wo = tf.keras.layers.Dense(config.d_model, use_bias=False, name="wo")
+        wi_initializer = tf.keras.initializers.RandomNormal(
+            mean=0, stddev=config.initializer_factor * (config.d_model ** -0.5)
+        )
+        wo_initializer = tf.keras.initializers.RandomNormal(
+            mean=0, stddev=config.initializer_factor * (config.d_ff ** -0.5)
+        )
+        self.wi = tf.keras.layers.Dense(
+            config.d_ff, use_bias=False, name="wi", kernel_initializer=wi_initializer
+        )  # Update init weights as in flax
+        self.wo = tf.keras.layers.Dense(
+            config.d_model, use_bias=False, name="wo", kernel_initializer=wo_initializer
+        )  # Update init weights as in flax
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
         self.act = tf.keras.activations.relu
 
@@ -109,9 +119,21 @@ class TFT5DenseReluDense(tf.keras.layers.Layer):
 class TFT5GatedGeluDense(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
         super().__init__(**kwargs)
-        self.wi_0 = tf.keras.layers.Dense(config.d_ff, use_bias=False, name="wi_0")
-        self.wi_1 = tf.keras.layers.Dense(config.d_ff, use_bias=False, name="wi_1")
-        self.wo = tf.keras.layers.Dense(config.d_model, use_bias=False, name="wo")
+        wi_initializer = tf.keras.initializers.RandomNormal(
+            mean=0, stddev=config.initializer_factor * (config.d_model ** -0.5)
+        )
+        wo_initializer = tf.keras.initializers.RandomNormal(
+            mean=0, stddev=config.initializer_factor * (config.d_ff ** -0.5)
+        )
+        self.wi_0 = tf.keras.layers.Dense(
+            config.d_ff, use_bias=False, name="wi_0", kernel_initializer=wi_initializer
+        )  # Update init weights as in flax
+        self.wi_1 = tf.keras.layers.Dense(
+            config.d_ff, use_bias=False, name="wi_1", kernel_initializer=wi_initializer
+        )  # Update init weights as in flax
+        self.wo = tf.keras.layers.Dense(
+            config.d_model, use_bias=False, name="wo", kernel_initializer=wo_initializer
+        )  # Update init weights as in flax
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
         self.act = get_tf_activation("gelu_new")
 
@@ -163,10 +185,34 @@ class TFT5Attention(tf.keras.layers.Layer):
         self.inner_dim = self.n_heads * self.key_value_proj_dim
 
         # Mesh TensorFlow initialization to avoid scaling before softmax
-        self.q = tf.keras.layers.Dense(self.inner_dim, use_bias=False, name="q")
-        self.k = tf.keras.layers.Dense(self.inner_dim, use_bias=False, name="k")
-        self.v = tf.keras.layers.Dense(self.inner_dim, use_bias=False, name="v")
-        self.o = tf.keras.layers.Dense(self.d_model, use_bias=False, name="o")
+        q_initializer = tf.keras.initializers.RandomNormal(
+            mean=0, stddev=config.initializer_factor * ((self.inner_dim * self.key_value_proj_dim) ** -0.5)
+        )
+        k_initializer = tf.keras.initializers.RandomNormal(
+            mean=0, stddev=config.initializer_factor * (self.inner_dim ** -0.5)
+        )
+        v_initializer = tf.keras.initializers.RandomNormal(
+            mean=0, stddev=config.initializer_factor * (self.inner_dim ** -0.5)
+        )
+        o_initializer = tf.keras.initializers.RandomNormal(
+            mean=0, stddev=config.initializer_factor * (self.inner_dim ** -0.5)
+        )
+        self.relative_attention_bias_initializer = tf.keras.initializers.RandomNormal(
+            mean=0, stddev=config.initializer_factor * (self.inner_dim ** -0.5)
+        )
+
+        self.q = tf.keras.layers.Dense(
+            self.inner_dim, use_bias=False, name="q", kernel_initializer=q_initializer
+        )  # Update init weights as in flax
+        self.k = tf.keras.layers.Dense(
+            self.inner_dim, use_bias=False, name="k", kernel_initializer=k_initializer
+        )  # Update init weights as in flax
+        self.v = tf.keras.layers.Dense(
+            self.inner_dim, use_bias=False, name="v", kernel_initializer=v_initializer
+        )  # Update init weights as in flax
+        self.o = tf.keras.layers.Dense(
+            self.d_model, use_bias=False, name="o", kernel_initializer=o_initializer
+        )  # Update init weights as in flax
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
 
         self.pruned_heads = set()
@@ -177,6 +223,7 @@ class TFT5Attention(tf.keras.layers.Layer):
             self.relative_attention_bias = self.add_weight(
                 name="embeddings",
                 shape=[self.relative_attention_num_buckets, self.n_heads],
+                initializer=self.relative_attention_bias_initializer,  # Add initializer
             )
 
         return super().build(input_shape)
@@ -1266,7 +1313,10 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
         self.decoder = TFT5MainLayer(decoder_config, embed_tokens, name="decoder")
 
         if not config.tie_word_embeddings:
-            self.lm_head = tf.keras.layers.Dense(config.vocab_size, use_bias=False, name="lm_head")
+            lm_head_initializer = tf.keras.initializers.RandomNormal(mean=0, stddev=config.initializer_factor)
+            self.lm_head = tf.keras.layers.Dense(
+                config.vocab_size, use_bias=False, name="lm_head", kernel_initializer=lm_head_initializer
+            )  # Update init weights as in flax
 
     def get_output_embeddings(self):
         if self.config.tie_word_embeddings:
@@ -1280,7 +1330,10 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
         if self.config.tie_word_embeddings:
             self.set_input_embeddings(value)
         else:
-            self.lm_head = tf.keras.layers.Dense(shape_list(value)[0], use_bias=False, name="lm_head")
+            lm_head_initializer = tf.keras.initializers.RandomNormal(mean=0, stddev=self.config.initializer_factor)
+            self.lm_head = tf.keras.layers.Dense(
+                shape_list(value)[0], use_bias=False, name="lm_head", kernel_initializer=lm_head_initializer
+            )  # Update init weights as in flax
             # in a dense layer the kernel has a shape (last_dim, units), for us (dim, num_tokens)
             # value has a shape (num_tokens, dim) then needs to be transposed
             transposed_value = tf.transpose(value)
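
A quick way to sanity-check the scheme introduced above is to build a Dense layer with the same RandomNormal initializer used for `wi` and compare the empirical standard deviation of its kernel against the intended `initializer_factor * d_model ** -0.5`. The snippet below is a standalone sketch, not part of the patch; the config values (`d_model=512`, `d_ff=2048`, `initializer_factor=1.0`) are illustrative placeholders rather than values taken from the diff.

# Standalone verification sketch (not part of the patch above).
# Builds a Dense layer with the same RandomNormal initializer as `wi`
# and checks that the sampled kernel std matches the intended scale.
# d_model / d_ff / initializer_factor are illustrative config values.
import numpy as np
import tensorflow as tf

d_model, d_ff, initializer_factor = 512, 2048, 1.0

wi_initializer = tf.keras.initializers.RandomNormal(
    mean=0, stddev=initializer_factor * (d_model ** -0.5)
)
wi = tf.keras.layers.Dense(d_ff, use_bias=False, kernel_initializer=wi_initializer)
wi.build((None, d_model))  # materialize the kernel without running data through it

observed = float(np.std(wi.kernel.numpy()))
expected = initializer_factor * d_model ** -0.5
print(f"kernel std: observed={observed:.4f}, expected={expected:.4f}")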