mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
parent
b651efe59e
commit
9ed80b0000
@ -2059,12 +2059,23 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
Return:
|
||||
`tf.keras.layers.Embedding`: Resized Embedding layer.
|
||||
"""
|
||||
|
||||
# Get the initialization range for the embeddings
|
||||
init_range = 0.02 # default value
|
||||
potential_initialization_variable_names = [
|
||||
"initializer_range", # most common
|
||||
"initializer_factor", # e.g. T5
|
||||
"init_std", # e.g BART
|
||||
]
|
||||
for var_name in potential_initialization_variable_names:
|
||||
if hasattr(self.config, var_name):
|
||||
init_range = getattr(self.config, var_name)
|
||||
|
||||
# Get a new (initialized) embeddings layer
|
||||
init_range = getattr(self.config, "initializer_range", 0.02)
|
||||
new_embeddings = tf.keras.layers.Embedding(
|
||||
input_dim=new_num_tokens,
|
||||
output_dim=old_embeddings.output_dim,
|
||||
embeddings_initializer=get_initializer(init_range),
|
||||
embeddings_initializer=tf.keras.initializers.TruncatedNormal(stddev=init_range),
|
||||
name=old_embeddings.embeddings.name[:-13], # exact same scoped name except "/embeddings:0"
|
||||
)
|
||||
new_embeddings(tf.constant([[0]]))
|
||||
|
@ -1053,7 +1053,12 @@ class TFBartMainLayer(tf.keras.layers.Layer):
|
||||
def __init__(self, config: BartConfig, load_weight_prefix=None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.config = config
|
||||
self.shared = tf.keras.layers.Embedding(config.vocab_size, config.d_model, name="model.shared")
|
||||
self.shared = tf.keras.layers.Embedding(
|
||||
input_dim=config.vocab_size,
|
||||
output_dim=config.d_model,
|
||||
embeddings_initializer=tf.keras.initializers.TruncatedNormal(stddev=self.config.init_std),
|
||||
name="model.shared",
|
||||
)
|
||||
# Additional attribute to specify the expected name scope of the layer (for loading/storing weights)
|
||||
self.shared.load_weight_prefix = "model.shared" if load_weight_prefix is None else load_weight_prefix
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user