Fix of issue #13327: Wrong weight initialization for TF t5 model (#14241)

* Fix of issue #13327: Wrong weight initialization for TF t5 model

* run black formatter

* fix typo

* remove my name tag from comments

Co-authored-by: Shirron <dan.shirron@intel.com>
This commit is contained in:
Dan Shirron 2021-11-03 18:20:48 +02:00 committed by GitHub
parent dec759e7e8
commit 2c8957feea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -93,8 +93,18 @@ class TFT5LayerNorm(tf.keras.layers.Layer):
class TFT5DenseReluDense(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.wi = tf.keras.layers.Dense(config.d_ff, use_bias=False, name="wi")
self.wo = tf.keras.layers.Dense(config.d_model, use_bias=False, name="wo")
wi_initializer = tf.keras.initializers.RandomNormal(
mean=0, stddev=config.initializer_factor * (config.d_model ** -0.5)
)
wo_initializer = tf.keras.initializers.RandomNormal(
mean=0, stddev=config.initializer_factor * (config.d_ff ** -0.5)
)
self.wi = tf.keras.layers.Dense(
config.d_ff, use_bias=False, name="wi", kernel_initializer=wi_initializer
) # Update init weights as in flax
self.wo = tf.keras.layers.Dense(
config.d_model, use_bias=False, name="wo", kernel_initializer=wo_initializer
) # Update init weights as in flax
self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
self.act = tf.keras.activations.relu
@ -109,9 +119,21 @@ class TFT5DenseReluDense(tf.keras.layers.Layer):
class TFT5GatedGeluDense(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.wi_0 = tf.keras.layers.Dense(config.d_ff, use_bias=False, name="wi_0")
self.wi_1 = tf.keras.layers.Dense(config.d_ff, use_bias=False, name="wi_1")
self.wo = tf.keras.layers.Dense(config.d_model, use_bias=False, name="wo")
wi_initializer = tf.keras.initializers.RandomNormal(
mean=0, stddev=config.initializer_factor * (config.d_model ** -0.5)
)
wo_initializer = tf.keras.initializers.RandomNormal(
mean=0, stddev=config.initializer_factor * (config.d_ff ** -0.5)
)
self.wi_0 = tf.keras.layers.Dense(
config.d_ff, use_bias=False, name="wi_0", kernel_initializer=wi_initializer
) # Update init weights as in flax
self.wi_1 = tf.keras.layers.Dense(
config.d_ff, use_bias=False, name="wi_1", kernel_initializer=wi_initializer
) # Update init weights as in flax
self.wo = tf.keras.layers.Dense(
config.d_model, use_bias=False, name="wo", kernel_initializer=wo_initializer
) # Update init weights as in flax
self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
self.act = get_tf_activation("gelu_new")
@ -163,10 +185,34 @@ class TFT5Attention(tf.keras.layers.Layer):
self.inner_dim = self.n_heads * self.key_value_proj_dim
# Mesh TensorFlow initialization to avoid scaling before softmax
self.q = tf.keras.layers.Dense(self.inner_dim, use_bias=False, name="q")
self.k = tf.keras.layers.Dense(self.inner_dim, use_bias=False, name="k")
self.v = tf.keras.layers.Dense(self.inner_dim, use_bias=False, name="v")
self.o = tf.keras.layers.Dense(self.d_model, use_bias=False, name="o")
q_initializer = tf.keras.initializers.RandomNormal(
mean=0, stddev=config.initializer_factor * ((self.inner_dim * self.key_value_proj_dim) ** -0.5)
)
k_initializer = tf.keras.initializers.RandomNormal(
mean=0, stddev=config.initializer_factor * (self.inner_dim ** -0.5)
)
v_initializer = tf.keras.initializers.RandomNormal(
mean=0, stddev=config.initializer_factor * (self.inner_dim ** -0.5)
)
o_initializer = tf.keras.initializers.RandomNormal(
mean=0, stddev=config.initializer_factor * (self.inner_dim ** -0.5)
)
self.relative_attention_bias_initializer = tf.keras.initializers.RandomNormal(
mean=0, stddev=config.initializer_factor * (self.inner_dim ** -0.5)
)
self.q = tf.keras.layers.Dense(
self.inner_dim, use_bias=False, name="q", kernel_initializer=q_initializer
) # Update init weights as in flax
self.k = tf.keras.layers.Dense(
self.inner_dim, use_bias=False, name="k", kernel_initializer=k_initializer
) # Update init weights as in flax
self.v = tf.keras.layers.Dense(
self.inner_dim, use_bias=False, name="v", kernel_initializer=v_initializer
) # Update init weights as in flax
self.o = tf.keras.layers.Dense(
self.d_model, use_bias=False, name="o", kernel_initializer=o_initializer
) # Update init weights as in flax
self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
self.pruned_heads = set()
@ -177,6 +223,7 @@ class TFT5Attention(tf.keras.layers.Layer):
self.relative_attention_bias = self.add_weight(
name="embeddings",
shape=[self.relative_attention_num_buckets, self.n_heads],
initializer=self.relative_attention_bias_initializer, # Add initializer
)
return super().build(input_shape)
@ -1266,7 +1313,10 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
self.decoder = TFT5MainLayer(decoder_config, embed_tokens, name="decoder")
if not config.tie_word_embeddings:
self.lm_head = tf.keras.layers.Dense(config.vocab_size, use_bias=False, name="lm_head")
lm_head_initializer = tf.keras.initializers.RandomNormal(mean=0, stddev=config.initializer_factor)
self.lm_head = tf.keras.layers.Dense(
config.vocab_size, use_bias=False, name="lm_head", kernel_initializer=lm_head_initializer
) # Update init weights as in flax
def get_output_embeddings(self):
if self.config.tie_word_embeddings:
@ -1280,7 +1330,10 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
if self.config.tie_word_embeddings:
self.set_input_embeddings(value)
else:
self.lm_head = tf.keras.layers.Dense(shape_list(value)[0], use_bias=False, name="lm_head")
lm_head_initializer = tf.keras.initializers.RandomNormal(mean=0, stddev=self.config.initializer_factor)
self.lm_head = tf.keras.layers.Dense(
shape_list(value)[0], use_bias=False, name="lm_head", kernel_initializer=lm_head_initializer
) # Update init weights as in flax
# in a dense layer the kernel has a shape (last_dim, units), for us (dim, num_tokens)
# value has a shape (num_tokens, dim) then needs to be transposed
transposed_value = tf.transpose(value)