mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 02:31:11 +06:00
* Fix of issue #13327: Wrong weight initialization for TF t5 model * run black formatter * fix typo * remove my name tag from comments Co-authored-by: Shirron <dan.shirron@intel.com>
This commit is contained in:
parent
dec759e7e8
commit
2c8957feea
@ -93,8 +93,18 @@ class TFT5LayerNorm(tf.keras.layers.Layer):
|
||||
class TFT5DenseReluDense(tf.keras.layers.Layer):
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.wi = tf.keras.layers.Dense(config.d_ff, use_bias=False, name="wi")
|
||||
self.wo = tf.keras.layers.Dense(config.d_model, use_bias=False, name="wo")
|
||||
wi_initializer = tf.keras.initializers.RandomNormal(
|
||||
mean=0, stddev=config.initializer_factor * (config.d_model ** -0.5)
|
||||
)
|
||||
wo_initializer = tf.keras.initializers.RandomNormal(
|
||||
mean=0, stddev=config.initializer_factor * (config.d_ff ** -0.5)
|
||||
)
|
||||
self.wi = tf.keras.layers.Dense(
|
||||
config.d_ff, use_bias=False, name="wi", kernel_initializer=wi_initializer
|
||||
) # Update init weights as in flax
|
||||
self.wo = tf.keras.layers.Dense(
|
||||
config.d_model, use_bias=False, name="wo", kernel_initializer=wo_initializer
|
||||
) # Update init weights as in flax
|
||||
self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
|
||||
self.act = tf.keras.activations.relu
|
||||
|
||||
@ -109,9 +119,21 @@ class TFT5DenseReluDense(tf.keras.layers.Layer):
|
||||
class TFT5GatedGeluDense(tf.keras.layers.Layer):
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.wi_0 = tf.keras.layers.Dense(config.d_ff, use_bias=False, name="wi_0")
|
||||
self.wi_1 = tf.keras.layers.Dense(config.d_ff, use_bias=False, name="wi_1")
|
||||
self.wo = tf.keras.layers.Dense(config.d_model, use_bias=False, name="wo")
|
||||
wi_initializer = tf.keras.initializers.RandomNormal(
|
||||
mean=0, stddev=config.initializer_factor * (config.d_model ** -0.5)
|
||||
)
|
||||
wo_initializer = tf.keras.initializers.RandomNormal(
|
||||
mean=0, stddev=config.initializer_factor * (config.d_ff ** -0.5)
|
||||
)
|
||||
self.wi_0 = tf.keras.layers.Dense(
|
||||
config.d_ff, use_bias=False, name="wi_0", kernel_initializer=wi_initializer
|
||||
) # Update init weights as in flax
|
||||
self.wi_1 = tf.keras.layers.Dense(
|
||||
config.d_ff, use_bias=False, name="wi_1", kernel_initializer=wi_initializer
|
||||
) # Update init weights as in flax
|
||||
self.wo = tf.keras.layers.Dense(
|
||||
config.d_model, use_bias=False, name="wo", kernel_initializer=wo_initializer
|
||||
) # Update init weights as in flax
|
||||
self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
|
||||
self.act = get_tf_activation("gelu_new")
|
||||
|
||||
@ -163,10 +185,34 @@ class TFT5Attention(tf.keras.layers.Layer):
|
||||
self.inner_dim = self.n_heads * self.key_value_proj_dim
|
||||
|
||||
# Mesh TensorFlow initialization to avoid scaling before softmax
|
||||
self.q = tf.keras.layers.Dense(self.inner_dim, use_bias=False, name="q")
|
||||
self.k = tf.keras.layers.Dense(self.inner_dim, use_bias=False, name="k")
|
||||
self.v = tf.keras.layers.Dense(self.inner_dim, use_bias=False, name="v")
|
||||
self.o = tf.keras.layers.Dense(self.d_model, use_bias=False, name="o")
|
||||
q_initializer = tf.keras.initializers.RandomNormal(
|
||||
mean=0, stddev=config.initializer_factor * ((self.inner_dim * self.key_value_proj_dim) ** -0.5)
|
||||
)
|
||||
k_initializer = tf.keras.initializers.RandomNormal(
|
||||
mean=0, stddev=config.initializer_factor * (self.inner_dim ** -0.5)
|
||||
)
|
||||
v_initializer = tf.keras.initializers.RandomNormal(
|
||||
mean=0, stddev=config.initializer_factor * (self.inner_dim ** -0.5)
|
||||
)
|
||||
o_initializer = tf.keras.initializers.RandomNormal(
|
||||
mean=0, stddev=config.initializer_factor * (self.inner_dim ** -0.5)
|
||||
)
|
||||
self.relative_attention_bias_initializer = tf.keras.initializers.RandomNormal(
|
||||
mean=0, stddev=config.initializer_factor * (self.inner_dim ** -0.5)
|
||||
)
|
||||
|
||||
self.q = tf.keras.layers.Dense(
|
||||
self.inner_dim, use_bias=False, name="q", kernel_initializer=q_initializer
|
||||
) # Update init weights as in flax
|
||||
self.k = tf.keras.layers.Dense(
|
||||
self.inner_dim, use_bias=False, name="k", kernel_initializer=k_initializer
|
||||
) # Update init weights as in flax
|
||||
self.v = tf.keras.layers.Dense(
|
||||
self.inner_dim, use_bias=False, name="v", kernel_initializer=v_initializer
|
||||
) # Update init weights as in flax
|
||||
self.o = tf.keras.layers.Dense(
|
||||
self.d_model, use_bias=False, name="o", kernel_initializer=o_initializer
|
||||
) # Update init weights as in flax
|
||||
self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
|
||||
|
||||
self.pruned_heads = set()
|
||||
@ -177,6 +223,7 @@ class TFT5Attention(tf.keras.layers.Layer):
|
||||
self.relative_attention_bias = self.add_weight(
|
||||
name="embeddings",
|
||||
shape=[self.relative_attention_num_buckets, self.n_heads],
|
||||
initializer=self.relative_attention_bias_initializer, # Add initializer
|
||||
)
|
||||
|
||||
return super().build(input_shape)
|
||||
@ -1266,7 +1313,10 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
|
||||
self.decoder = TFT5MainLayer(decoder_config, embed_tokens, name="decoder")
|
||||
|
||||
if not config.tie_word_embeddings:
|
||||
self.lm_head = tf.keras.layers.Dense(config.vocab_size, use_bias=False, name="lm_head")
|
||||
lm_head_initializer = tf.keras.initializers.RandomNormal(mean=0, stddev=config.initializer_factor)
|
||||
self.lm_head = tf.keras.layers.Dense(
|
||||
config.vocab_size, use_bias=False, name="lm_head", kernel_initializer=lm_head_initializer
|
||||
) # Update init weights as in flax
|
||||
|
||||
def get_output_embeddings(self):
|
||||
if self.config.tie_word_embeddings:
|
||||
@ -1280,7 +1330,10 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
|
||||
if self.config.tie_word_embeddings:
|
||||
self.set_input_embeddings(value)
|
||||
else:
|
||||
self.lm_head = tf.keras.layers.Dense(shape_list(value)[0], use_bias=False, name="lm_head")
|
||||
lm_head_initializer = tf.keras.initializers.RandomNormal(mean=0, stddev=self.config.initializer_factor)
|
||||
self.lm_head = tf.keras.layers.Dense(
|
||||
shape_list(value)[0], use_bias=False, name="lm_head", kernel_initializer=lm_head_initializer
|
||||
) # Update init weights as in flax
|
||||
# in a dense layer the kernel has a shape (last_dim, units), for us (dim, num_tokens)
|
||||
# value has a shape (num_tokens, dim) then needs to be transposed
|
||||
transposed_value = tf.transpose(value)
|
||||
|
Loading…
Reference in New Issue
Block a user