Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 18:22:34 +06:00)
Fix LongT5ForConditionalGeneration initialization of lm_head (#28873)
Parent: 1ea0bbd73c
Commit: ee2a3400f2
@@ -1301,6 +1301,8 @@ class LongT5PreTrainedModel(PreTrainedModel):
             # Mesh TensorFlow embeddings initialization
             # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
             module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
+            if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
+                module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0)
         elif isinstance(module, LongT5DenseActDense):
             # Mesh TensorFlow FF initialization
             # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
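
For context, a minimal sketch (not part of the commit) of the case the fix targets: when tie_word_embeddings is False, LongT5ForConditionalGeneration keeps a separate lm_head, which previously fell back to PyTorch's default Linear initialization instead of the Mesh TensorFlow scheme normal(mean=0.0, std=initializer_factor * 1.0). The tiny config values below are assumptions chosen only to keep the smoke test fast.

    # Minimal sketch, assuming a transformers release with LongT5 support.
    from transformers import LongT5Config, LongT5ForConditionalGeneration

    config = LongT5Config(
        vocab_size=128,   # arbitrary small values, just for a quick check
        d_model=32,
        d_kv=8,
        d_ff=64,
        num_layers=2,
        num_heads=2,
        tie_word_embeddings=False,  # untied lm_head is the case the commit fixes
    )
    model = LongT5ForConditionalGeneration(config)

    # With the fix, lm_head.weight is drawn from
    # normal(mean=0.0, std=config.initializer_factor * 1.0)
    print(model.lm_head.weight.std())  # should be close to config.initializer_factor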