TF: TFBart embedding initialization (#19460)

* correct embedding init
This commit is contained in:
Joao Gante 2022-10-11 14:44:46 +01:00 committed by GitHub
parent b651efe59e
commit 9ed80b0000
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 19 additions and 3 deletions

View File

@ -2059,12 +2059,23 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
Return:
`tf.keras.layers.Embedding`: Resized Embedding layer.
"""
# Get the initialization range for the embeddings
init_range = 0.02 # default value
potential_initialization_variable_names = [
"initializer_range", # most common
"initializer_factor", # e.g. T5
"init_std", # e.g BART
]
for var_name in potential_initialization_variable_names:
if hasattr(self.config, var_name):
init_range = getattr(self.config, var_name)
# Get a new (initialized) embeddings layer
init_range = getattr(self.config, "initializer_range", 0.02)
new_embeddings = tf.keras.layers.Embedding(
input_dim=new_num_tokens,
output_dim=old_embeddings.output_dim,
embeddings_initializer=get_initializer(init_range),
embeddings_initializer=tf.keras.initializers.TruncatedNormal(stddev=init_range),
name=old_embeddings.embeddings.name[:-13], # exact same scoped name except "/embeddings:0"
)
new_embeddings(tf.constant([[0]]))

View File

@ -1053,7 +1053,12 @@ class TFBartMainLayer(tf.keras.layers.Layer):
def __init__(self, config: BartConfig, load_weight_prefix=None, **kwargs):
super().__init__(**kwargs)
self.config = config
self.shared = tf.keras.layers.Embedding(config.vocab_size, config.d_model, name="model.shared")
self.shared = tf.keras.layers.Embedding(
input_dim=config.vocab_size,
output_dim=config.d_model,
embeddings_initializer=tf.keras.initializers.TruncatedNormal(stddev=self.config.init_std),
name="model.shared",
)
# Additional attribute to specify the expected name scope of the layer (for loading/storing weights)
self.shared.load_weight_prefix = "model.shared" if load_weight_prefix is None else load_weight_prefix