Mirror of https://github.com/huggingface/transformers.git, synced 2025-08-03 03:31:05 +06:00
[T5] Fix init in TF and Flax for pretraining (#17294)
* fix init
* Apply suggestions from code review
* fix
* finish
* Update src/transformers/modeling_tf_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent 7ba1d4e51f
commit 60ad73448c
@@ -768,6 +768,8 @@ class T5PreTrainedModel(PreTrainedModel):
             # Mesh TensorFlow embeddings initialization
             # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
             module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
+            if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
+                module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0)
         elif isinstance(module, T5DenseReluDense):
             # Mesh TensorFlow FF initialization
             # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
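The PyTorch-side change above extends `_init_weights` to the untied LM head: when `tie_word_embeddings` is False, `lm_head` has its own weight matrix, which previously kept the framework default instead of the Mesh TensorFlow scheme. A minimal standalone sketch of the resulting behavior, assuming t5-small-like dimensions (not part of the diff):

    # Sketch of the added init branch with hypothetical dimensions:
    # the untied lm_head gets the same normal init as the shared embeddings.
    import torch.nn as nn

    d_model, vocab_size = 512, 32128
    factor = 1.0  # plays the role of config.initializer_factor

    lm_head = nn.Linear(d_model, vocab_size, bias=False)
    lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0)

    print(lm_head.weight.std())  # ~1.0, matching module.shared above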
@@ -1112,7 +1112,9 @@ num_heads))`.
 class TFT5Model(TFT5PreTrainedModel):
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
-        self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, name="shared")
+        self.shared = TFSharedEmbeddings(
+            config.vocab_size, config.d_model, name="shared", initializer_range=self.config.initializer_factor
+        )
 
         # retrieve correct absolute scope for embed token wrapper
         with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
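Passing initializer_range explicitly matters because TFSharedEmbeddings falls back to hidden_size ** -0.5 when none is given, so the TF T5 embeddings were initialized far more narrowly than their PyTorch counterparts. A simplified sketch of that fallback, condensed from the TFSharedEmbeddings layer (not the full class):

    # Condensed sketch of the TFSharedEmbeddings init logic (simplified):
    import tensorflow as tf

    def embedding_initializer(hidden_size, initializer_range=None):
        # Fallback used before this fix: d_model ** -0.5 (~0.044 for d_model=512)
        std = hidden_size ** -0.5 if initializer_range is None else initializer_range
        return tf.keras.initializers.TruncatedNormal(stddev=std)

    old_init = embedding_initializer(512)                          # stddev ~0.044
    new_init = embedding_initializer(512, initializer_range=1.0)   # stddev = config.initializer_factor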
@@ -1259,8 +1261,9 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
         self.model_dim = config.d_model
 
-        self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, name="shared")
+        self.shared = TFSharedEmbeddings(
+            config.vocab_size, config.d_model, name="shared", initializer_range=self.config.initializer_factor
+        )
 
         # retrieve correct absolute scope for embed token wrapper
         with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
@@ -1600,7 +1603,9 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
 class TFT5EncoderModel(TFT5PreTrainedModel):
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
-        self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, name="shared")
+        self.shared = TFSharedEmbeddings(
+            config.vocab_size, config.d_model, name="shared", initializer_range=self.config.initializer_factor
+        )
 
         # retrieve correct absolute scope for embed token wrapper
         with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
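The same one-line change is applied to all three TF classes (TFT5Model, TFT5ForConditionalGeneration, TFT5EncoderModel). A rough way to check the effect after this patch (a sketch, not part of the commit; assumes a local TensorFlow install and the default T5 config values):

    # Rough verification sketch: the shared embedding weights should now have
    # a spread on the order of initializer_factor (1.0), not d_model ** -0.5.
    import numpy as np
    from transformers import T5Config, TFT5Model

    config = T5Config()  # defaults: vocab_size=32128, d_model=512, initializer_factor=1.0
    model = TFT5Model(config)
    model(model.dummy_inputs)  # run a forward pass so the weights are built

    # TruncatedNormal(stddev=1.0) samples have std ~0.88, still far from ~0.044
    print(np.std(model.shared.weight.numpy()))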