load pretrained embeddings in Bert decoder

In Rothe et al.'s "Leveraging Pre-trained Checkpoints for Sequence
Generation Tasks", Bert2Bert is initialized with pre-trained weights for
the encoder, and only pre-trained embeddings for the decoder. The
current version of the code completely randomizes the weights of the
decoder.

We write a custom function to initialize the weights of the decoder: we
first load the decoder with the pretrained weights and then re-randomize
everything but the embeddings.
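
The pattern is plain weight surgery: start from a fully pretrained module,
then re-draw every parameter except the embedding matrices. Below is a
minimal, self-contained sketch of the idea on a toy module (the class and
function names are illustrative, not the repository's):

    import torch.nn as nn

    class ToyDecoder(nn.Module):
        """Stand-in for a BERT-style decoder: embeddings + linear + LayerNorm."""
        def __init__(self, vocab_size=30522, hidden=768):
            super().__init__()
            self.word_embeddings = nn.Embedding(vocab_size, hidden)
            self.dense = nn.Linear(hidden, hidden)
            self.layer_norm = nn.LayerNorm(hidden)

    def randomize_all_but_embeddings(module, initializer_range=0.02):
        # nn.Embedding is deliberately left untouched: those weights stay pretrained.
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    decoder = ToyDecoder()  # pretend these weights came from a checkpoint
    decoder.apply(randomize_all_but_embeddings)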
Rémi Louf 2019-10-11 16:48:11 +02:00
parent 1e68c28670
commit f8e98d6779

@@ -1348,15 +1348,14 @@ class Bert2Rnd(BertPreTrainedModel):
         self.encoder = BertModel(config)
         self.decoder = BertDecoderModel(config)
-        self.init_weights()
 
     @classmethod
     def from_pretrained(cls, pretrained_model_or_path, *model_args, **model_kwargs):
         """ Load the pretrained weights in the encoder.
 
-        Since the decoder needs to be initialized with random weights, and the encoder with
-        pretrained weights we need to override the `from_pretrained` method of the base `PreTrainedModel`
-        class.
+        The encoder of `Bert2Rnd` is initialized with pretrained weights; the
+        weights of the decoder are initialized at random, except the embeddings,
+        which are initialized with the pretrained embeddings. We thus need to
+        override the base class' `from_pretrained` method.
         """
         # Load the configuration
@@ -1374,10 +1373,26 @@ class Bert2Rnd(BertPreTrainedModel):
         )
 
         model = cls(config)
 
-        # The encoder is loaded with pretrained weights
+        # We load the encoder with pretrained weights
         pretrained_encoder = BertModel.from_pretrained(pretrained_model_or_path, *model_args, **model_kwargs)
         model.encoder = pretrained_encoder
 
+        # We load the decoder with pretrained weights and then randomize all
+        # weights but the embedding-related ones.
+        def randomize_decoder_weights(module):
+            if isinstance(module, nn.Linear):
+                # Slightly different from the TF version, which uses truncated_normal
+                # for initialization, cf. https://github.com/pytorch/pytorch/pull/5617
+                module.weight.data.normal_(mean=0.0, std=config.initializer_range)
+            elif isinstance(module, BertLayerNorm):
+                module.bias.data.zero_()
+                module.weight.data.fill_(1.0)
+            if isinstance(module, nn.Linear) and module.bias is not None:
+                module.bias.data.zero_()
+
+        pretrained_decoder = BertDecoderModel.from_pretrained(pretrained_model_or_path, *model_args, **model_kwargs)
+        pretrained_decoder.apply(randomize_decoder_weights)
+        model.decoder = pretrained_decoder
+
         return model
 
     def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
@@ -1386,11 +1401,9 @@ class Bert2Rnd(BertPreTrainedModel):
                                        token_type_ids=token_type_ids,
                                        position_ids=position_ids,
                                        head_mask=head_mask)
-        encoder_output = encoder_outputs[0]
 
         decoder_outputs = self.decoder(input_ids,
-                                       encoder_output,
+                                       encoder_outputs[0],
                                        token_type_ids=token_type_ids,
                                        position_ids=position_ids,
                                        head_mask=head_mask)
-        return decoder_outputs[0]
+        return decoder_outputs
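
A quick sanity check one could run after this change. This is a sketch only:
the attribute paths on `model.decoder` are assumed to mirror `BertModel`'s
layout (`embeddings.word_embeddings`, `encoder.layer[0]...`), which this diff
does not show.

    import torch

    model = Bert2Rnd.from_pretrained("bert-base-uncased")
    reference = BertModel.from_pretrained("bert-base-uncased")

    # Embeddings should be carried over from the checkpoint unchanged...
    assert torch.equal(model.decoder.embeddings.word_embeddings.weight,
                       reference.embeddings.word_embeddings.weight)

    # ...while other weights should have been re-drawn (assumed attribute path;
    # equality after a fresh random draw is astronomically unlikely).
    assert not torch.equal(model.decoder.encoder.layer[0].attention.self.query.weight,
                           reference.encoder.layer[0].attention.self.query.weight)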