[T5] Fix init in TF and Flax for pretraining (#17294)

* fix init

* Apply suggestions from code review

* fix

* finish

* Update src/transformers/modeling_tf_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Patrick von Platen 2022-05-18 15:08:56 +02:00 committed by GitHub
parent 7ba1d4e51f
commit 60ad73448c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 11 additions and 4 deletions

View File

@ -768,6 +768,8 @@ class T5PreTrainedModel(PreTrainedModel):
# Mesh TensorFlow embeddings initialization
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0)
elif isinstance(module, T5DenseReluDense):
# Mesh TensorFlow FF initialization
# See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56

View File

@ -1112,7 +1112,9 @@ num_heads))`.
class TFT5Model(TFT5PreTrainedModel):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, name="shared")
self.shared = TFSharedEmbeddings(
config.vocab_size, config.d_model, name="shared", initializer_range=self.config.initializer_factor
)
# retrieve correct absolute scope for embed token wrapper
with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
@ -1259,8 +1261,9 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.model_dim = config.d_model
self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, name="shared")
self.shared = TFSharedEmbeddings(
config.vocab_size, config.d_model, name="shared", initializer_range=self.config.initializer_factor
)
# retrieve correct absolute scope for embed token wrapper
with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
@ -1600,7 +1603,9 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
class TFT5EncoderModel(TFT5PreTrainedModel):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, name="shared")
self.shared = TFSharedEmbeddings(
config.vocab_size, config.d_model, name="shared", initializer_range=self.config.initializer_factor
)
# retrieve correct absolute scope for embed token wrapper
with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name: