Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
load pretrained embeddings in Bert decoder
In Rothe et al.'s "Leveraging Pre-trained Checkpoints for Sequence Generation Tasks", Bert2Bert is initialized with pre-trained weights for the encoder, but only the pre-trained embeddings for the decoder. The current version of the code randomizes all of the decoder's weights. We write a custom function to initialize the weights of the decoder: we first load the decoder with the pretrained weights and then re-randomize everything but the embeddings.
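For illustration, here is a minimal, self-contained sketch of that recipe using the stock BertModel from transformers (the fork's Bert2Rnd / BertDecoderModel classes are not part of mainline transformers, so this is an approximation of the commit's approach, not the commit itself): re-initialize every nn.Linear and LayerNorm, and leave nn.Embedding modules alone so the pretrained embeddings survive.

# Sketch only: mimics the commit's randomize-everything-but-embeddings pass
# on a stock BertModel instead of the fork's BertDecoderModel.
import torch
from torch import nn
from transformers import BertModel

decoder = BertModel.from_pretrained("bert-base-uncased")  # any BERT checkpoint works here
config = decoder.config

def randomize_all_but_embeddings(module):
    # nn.Embedding modules are deliberately not matched, so the pretrained
    # word / position / token-type embeddings survive this pass.
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=config.initializer_range)
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)

pretrained_word_emb = decoder.embeddings.word_embeddings.weight.clone()
decoder.apply(randomize_all_but_embeddings)  # .apply() visits every submodule recursively

# The embeddings are untouched; the transformer layers are now random.
assert torch.equal(decoder.embeddings.word_embeddings.weight, pretrained_word_emb)

Because nn.Module.apply() walks the whole module tree, skipping the embedding branch is all it takes to preserve the pretrained embeddings while everything else is re-drawn from the initializer.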
This commit is contained in:
parent 1e68c28670
commit f8e98d6779
@@ -1348,15 +1348,14 @@ class Bert2Rnd(BertPreTrainedModel):
         self.encoder = BertModel(config)
         self.decoder = BertDecoderModel(config)

         self.init_weights()

     @classmethod
     def from_pretrained(cls, pretrained_model_or_path, *model_args, **model_kwargs):
         """ Load the pretrained weights in the encoder.

-        Since the decoder needs to be initialized with random weights, and the encoder with
-        pretrained weights we need to override the `from_pretrained` method of the base `PreTrainedModel`
-        class.
+        The encoder of `Bert2Rand` is initialized with pretrained weights; the
+        weights of the decoder are initialized at random except the embeddings
+        which are initialized with the pretrained embeddings. We thus need to override
+        the base class' `from_pretrained` method.
         """

         # Load the configuration
@@ -1374,10 +1373,26 @@ class Bert2Rnd(BertPreTrainedModel):
         )
         model = cls(config)

-        # The encoder is loaded with pretrained weights
+        # We load the encoder with pretrained weights
         pretrained_encoder = BertModel.from_pretrained(pretrained_model_or_path, *model_args, **model_kwargs)
         model.encoder = pretrained_encoder

+        # We load the decoder with pretrained weights and then randomize all weights but embeddings-related one.
+        def randomize_decoder_weights(module):
+            if isinstance(module, nn.Linear):
+                # Slightly different from the TF version which uses truncated_normal for initialization
+                # cf https://github.com/pytorch/pytorch/pull/5617
+                module.weight.data.normal_(mean=0.0, std=config.initializer_range)
+            elif isinstance(module, BertLayerNorm):
+                module.bias.data.zero_()
+                module.weight.data.fill_(1.0)
+            if isinstance(module, nn.Linear) and module.bias is not None:
+                module.bias.data.zero_()
+
+        pretrained_decoder = BertDecoderModel.from_pretrained(pretrained_model_or_path, *model_args, **model_kwargs)
+        pretrained_decoder.apply(randomize_decoder_weights)
+        model.decoder = pretrained_decoder

         return model

     def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
@@ -1386,11 +1401,9 @@ class Bert2Rnd(BertPreTrainedModel):
                                        token_type_ids=token_type_ids,
                                        position_ids=position_ids,
                                        head_mask=head_mask)
-        encoder_output = encoder_outputs[0]
-
         decoder_outputs = self.decoder(input_ids,
-                                       encoder_output,
+                                       encoder_outputs[0],
                                        token_type_ids=token_type_ids,
                                        position_ids=position_ids,
                                        head_mask=head_mask)
-        return decoder_outputs[0]
+        return decoder_outputs
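For context, a hypothetical usage sketch of the class this diff modifies; Bert2Rnd, BertDecoderModel and the import path below exist only in this fork of the BERT modeling file, not in mainline transformers, so treat the names as assumptions.

# Hypothetical usage of the overridden from_pretrained; "modeling_bert" here
# refers to this fork's modified module, so the import is an assumption.
import torch
from modeling_bert import Bert2Rnd

# Encoder gets the full pretrained checkpoint; decoder gets only its embeddings.
model = Bert2Rnd.from_pretrained("bert-base-uncased")

input_ids = torch.tensor([[101, 7592, 2088, 102]])  # [CLS] hello world [SEP]
attention_mask = torch.ones_like(input_ids)

# After this commit, forward() feeds encoder_outputs[0] straight into the decoder
# and returns the decoder outputs tuple rather than only its first element.
decoder_outputs = model(input_ids, attention_mask=attention_mask)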