Fix LongT5ForConditionalGeneration initialization of lm_head (#28873)

This commit is contained in:
Eran Hirsch 2024-02-06 05:24:20 +02:00 committed by GitHub
parent 1ea0bbd73c
commit ee2a3400f2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1301,6 +1301,8 @@ class LongT5PreTrainedModel(PreTrainedModel):
# Mesh TensorFlow embeddings initialization
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0)
elif isinstance(module, LongT5DenseActDense):
    # Mesh TensorFlow FF initialization
    # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56