mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 18:51:14 +06:00
* Fix issue #13327: wrong weight initialization for the TF T5 model
* Run black formatter
* Fix typo
* Remove my name tag from comments

Co-authored-by: Shirron <dan.shirron@intel.com>
This commit is contained in:
parent dec759e7e8
commit 2c8957feea
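Context for the diff below: before this change, every `tf.keras.layers.Dense` in the TF T5 model was built without a `kernel_initializer`, so it fell back to Keras's default `glorot_uniform`, whereas the reference Flax/Mesh-TensorFlow T5 draws weights from a zero-mean normal whose stddev is `initializer_factor * fan_in ** -0.5`. A minimal sketch of the difference, assuming t5-small-like dimensions (not values read from a real config):

# Sketch only: d_model, d_ff, and initializer_factor are assumed values.
import tensorflow as tf

d_model, d_ff, initializer_factor = 512, 2048, 1.0

# Pre-fix behaviour: no kernel_initializer means glorot_uniform, whose
# scale depends on both fan-in and fan-out.
default_init = tf.keras.initializers.GlorotUniform()

# Post-fix behaviour: Flax-style zero-mean normal scaled by fan-in only.
wi_init = tf.keras.initializers.RandomNormal(
    mean=0, stddev=initializer_factor * (d_model ** -0.5)
)

print(tf.math.reduce_std(default_init((d_model, d_ff))).numpy())  # ~0.028
print(tf.math.reduce_std(wi_init((d_model, d_ff))).numpy())       # ~0.044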
@@ -93,8 +93,18 @@ class TFT5LayerNorm(tf.keras.layers.Layer):
 class TFT5DenseReluDense(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
         super().__init__(**kwargs)
-        self.wi = tf.keras.layers.Dense(config.d_ff, use_bias=False, name="wi")
-        self.wo = tf.keras.layers.Dense(config.d_model, use_bias=False, name="wo")
+        wi_initializer = tf.keras.initializers.RandomNormal(
+            mean=0, stddev=config.initializer_factor * (config.d_model ** -0.5)
+        )
+        wo_initializer = tf.keras.initializers.RandomNormal(
+            mean=0, stddev=config.initializer_factor * (config.d_ff ** -0.5)
+        )
+        self.wi = tf.keras.layers.Dense(
+            config.d_ff, use_bias=False, name="wi", kernel_initializer=wi_initializer
+        )  # Update init weights as in flax
+        self.wo = tf.keras.layers.Dense(
+            config.d_model, use_bias=False, name="wo", kernel_initializer=wo_initializer
+        )  # Update init weights as in flax
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
         self.act = tf.keras.activations.relu
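Note the asymmetry in the hunk above: `wi` maps d_model to d_ff, so its stddev uses `d_model ** -0.5`, while `wo` maps d_ff back to d_model and uses `d_ff ** -0.5`; each kernel is scaled by its own fan-in. A standalone sketch (assumed dimensions; the names `wi`/`wo` mirror the diff):

import tensorflow as tf

d_model, d_ff, factor = 512, 2048, 1.0  # assumed t5-small-like values

wi = tf.keras.layers.Dense(
    d_ff,
    use_bias=False,
    kernel_initializer=tf.keras.initializers.RandomNormal(mean=0, stddev=factor * d_model ** -0.5),
)
wo = tf.keras.layers.Dense(
    d_model,
    use_bias=False,
    kernel_initializer=tf.keras.initializers.RandomNormal(mean=0, stddev=factor * d_ff ** -0.5),
)
wi.build((None, d_model))  # kernel shape (d_model, d_ff)
wo.build((None, d_ff))     # kernel shape (d_ff, d_model)

print(tf.math.reduce_std(wi.kernel).numpy())  # ~ 512 ** -0.5 ~ 0.044
print(tf.math.reduce_std(wo.kernel).numpy())  # ~ 2048 ** -0.5 ~ 0.022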
@@ -109,9 +119,21 @@ class TFT5DenseReluDense(tf.keras.layers.Layer):
 class TFT5GatedGeluDense(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
         super().__init__(**kwargs)
-        self.wi_0 = tf.keras.layers.Dense(config.d_ff, use_bias=False, name="wi_0")
-        self.wi_1 = tf.keras.layers.Dense(config.d_ff, use_bias=False, name="wi_1")
-        self.wo = tf.keras.layers.Dense(config.d_model, use_bias=False, name="wo")
+        wi_initializer = tf.keras.initializers.RandomNormal(
+            mean=0, stddev=config.initializer_factor * (config.d_model ** -0.5)
+        )
+        wo_initializer = tf.keras.initializers.RandomNormal(
+            mean=0, stddev=config.initializer_factor * (config.d_ff ** -0.5)
+        )
+        self.wi_0 = tf.keras.layers.Dense(
+            config.d_ff, use_bias=False, name="wi_0", kernel_initializer=wi_initializer
+        )  # Update init weights as in flax
+        self.wi_1 = tf.keras.layers.Dense(
+            config.d_ff, use_bias=False, name="wi_1", kernel_initializer=wi_initializer
+        )  # Update init weights as in flax
+        self.wo = tf.keras.layers.Dense(
+            config.d_model, use_bias=False, name="wo", kernel_initializer=wo_initializer
+        )  # Update init weights as in flax
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
         self.act = get_tf_activation("gelu_new")
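The gated variant applies the same recipe: `wi_0` and `wi_1` both consume d_model-sized inputs, so they share `wi_initializer`, and only `wo` (fan-in d_ff) gets its own stddev. A sketch of the gated-GELU data flow under assumed dimensions (initializers omitted here for brevity; `approximate=True` is close to the model's "gelu_new" activation):

import tensorflow as tf

d_model, d_ff = 512, 2048  # assumed t5-small-like sizes
x = tf.random.normal((2, 16, d_model))

wi_0 = tf.keras.layers.Dense(d_ff, use_bias=False)   # gate branch, fan-in d_model
wi_1 = tf.keras.layers.Dense(d_ff, use_bias=False)   # linear branch, fan-in d_model
wo = tf.keras.layers.Dense(d_model, use_bias=False)  # fan-in d_ff

h = tf.keras.activations.gelu(wi_0(x), approximate=True) * wi_1(x)
print(wo(h).shape)  # (2, 16, 512)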
@@ -163,10 +185,34 @@ class TFT5Attention(tf.keras.layers.Layer):
         self.inner_dim = self.n_heads * self.key_value_proj_dim

         # Mesh TensorFlow initialization to avoid scaling before softmax
-        self.q = tf.keras.layers.Dense(self.inner_dim, use_bias=False, name="q")
-        self.k = tf.keras.layers.Dense(self.inner_dim, use_bias=False, name="k")
-        self.v = tf.keras.layers.Dense(self.inner_dim, use_bias=False, name="v")
-        self.o = tf.keras.layers.Dense(self.d_model, use_bias=False, name="o")
+        q_initializer = tf.keras.initializers.RandomNormal(
+            mean=0, stddev=config.initializer_factor * ((self.inner_dim * self.key_value_proj_dim) ** -0.5)
+        )
+        k_initializer = tf.keras.initializers.RandomNormal(
+            mean=0, stddev=config.initializer_factor * (self.inner_dim ** -0.5)
+        )
+        v_initializer = tf.keras.initializers.RandomNormal(
+            mean=0, stddev=config.initializer_factor * (self.inner_dim ** -0.5)
+        )
+        o_initializer = tf.keras.initializers.RandomNormal(
+            mean=0, stddev=config.initializer_factor * (self.inner_dim ** -0.5)
+        )
+        self.relative_attention_bias_initializer = tf.keras.initializers.RandomNormal(
+            mean=0, stddev=config.initializer_factor * (self.inner_dim ** -0.5)
+        )
+
+        self.q = tf.keras.layers.Dense(
+            self.inner_dim, use_bias=False, name="q", kernel_initializer=q_initializer
+        )  # Update init weights as in flax
+        self.k = tf.keras.layers.Dense(
+            self.inner_dim, use_bias=False, name="k", kernel_initializer=k_initializer
+        )  # Update init weights as in flax
+        self.v = tf.keras.layers.Dense(
+            self.inner_dim, use_bias=False, name="v", kernel_initializer=v_initializer
+        )  # Update init weights as in flax
+        self.o = tf.keras.layers.Dense(
+            self.d_model, use_bias=False, name="o", kernel_initializer=o_initializer
+        )  # Update init weights as in flax
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)

         self.pruned_heads = set()
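The comment "Mesh TensorFlow initialization to avoid scaling before softmax" explains why `q` alone gets the extra `key_value_proj_dim ** -0.5` factor: T5 never divides attention scores by sqrt(d_k), so that division is folded into `q`'s initial weights instead. A rough numeric sketch of the effect, with assumed t5-small-like head sizes:

import tensorflow as tf

d_model, n_heads, d_kv = 512, 8, 64  # assumed values
inner_dim = n_heads * d_kv           # 512

q_std = (inner_dim * d_kv) ** -0.5  # extra d_kv ** -0.5 relative to k
k_std = inner_dim ** -0.5

x = tf.random.normal((1000, d_model))  # unit-variance hidden states
q = x @ tf.random.normal((d_model, inner_dim), stddev=q_std)
k = x @ tf.random.normal((d_model, inner_dim), stddev=k_std)

# Per-head q.k scores already have roughly unit variance at init,
# with no explicit 1/sqrt(d_kv) rescaling of the scores:
scores = tf.reduce_sum(q[:, :d_kv] * k[:, :d_kv], axis=-1)
print(tf.math.reduce_std(scores).numpy())  # ~ 1.0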
@@ -177,6 +223,7 @@ class TFT5Attention(tf.keras.layers.Layer):
             self.relative_attention_bias = self.add_weight(
                 name="embeddings",
                 shape=[self.relative_attention_num_buckets, self.n_heads],
+                initializer=self.relative_attention_bias_initializer,  # Add initializer
             )

         return super().build(input_shape)
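One line fixes the remaining weight: `relative_attention_bias` is created via `add_weight`, which likewise defaults to `glorot_uniform` for float weights when no initializer is passed, so the `relative_attention_bias_initializer` stored in the constructor is now threaded through. A standalone sketch of that pattern (the `Demo` layer and sizes are hypothetical, not code from the diff):

import tensorflow as tf

class Demo(tf.keras.layers.Layer):
    """Hypothetical layer showing add_weight with an explicit initializer."""

    def build(self, input_shape):
        init = tf.keras.initializers.RandomNormal(mean=0, stddev=512 ** -0.5)
        # Without initializer=..., add_weight would fall back to glorot_uniform.
        self.bias_table = self.add_weight(
            name="embeddings", shape=[32, 8], initializer=init
        )
        return super().build(input_shape)

layer = Demo()
layer.build(None)
print(tf.math.reduce_std(layer.bias_table).numpy())  # ~ 512 ** -0.5 ~ 0.044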
@@ -1266,7 +1313,10 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
         self.decoder = TFT5MainLayer(decoder_config, embed_tokens, name="decoder")

         if not config.tie_word_embeddings:
-            self.lm_head = tf.keras.layers.Dense(config.vocab_size, use_bias=False, name="lm_head")
+            lm_head_initializer = tf.keras.initializers.RandomNormal(mean=0, stddev=config.initializer_factor)
+            self.lm_head = tf.keras.layers.Dense(
+                config.vocab_size, use_bias=False, name="lm_head", kernel_initializer=lm_head_initializer
+            )  # Update init weights as in flax

     def get_output_embeddings(self):
         if self.config.tie_word_embeddings:
@@ -1280,7 +1330,10 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
         if self.config.tie_word_embeddings:
             self.set_input_embeddings(value)
         else:
-            self.lm_head = tf.keras.layers.Dense(shape_list(value)[0], use_bias=False, name="lm_head")
+            lm_head_initializer = tf.keras.initializers.RandomNormal(mean=0, stddev=self.config.initializer_factor)
+            self.lm_head = tf.keras.layers.Dense(
+                shape_list(value)[0], use_bias=False, name="lm_head", kernel_initializer=lm_head_initializer
+            )  # Update init weights as in flax
             # in a dense layer the kernel has a shape (last_dim, units), for us (dim, num_tokens)
             # value has a shape (num_tokens, dim) then needs to be transposed
             transposed_value = tf.transpose(value)
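Two details in these last hunks: the untied `lm_head` uses a plain `stddev=config.initializer_factor` with no fan-in scaling, matching the Flax scheme per the diff's own comments, and the existing comment explains the transpose: a Dense kernel is laid out as (last_dim, units), i.e. (dim, num_tokens), while the embedding matrix `value` arrives as (num_tokens, dim). A minimal standalone sketch of that shape bookkeeping, with assumed sizes:

import tensorflow as tf

num_tokens, dim = 32128, 512  # assumed t5-small-like vocab size and width
value = tf.random.normal((num_tokens, dim))  # embedding-matrix layout

lm_head = tf.keras.layers.Dense(num_tokens, use_bias=False)
lm_head.build((None, dim))  # Dense kernel layout: (last_dim, units) = (dim, num_tokens)

# The transpose from the diff makes the two layouts agree:
lm_head.set_weights([tf.transpose(value).numpy()])
print(lm_head.kernel.shape)  # (512, 32128)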