load pretrained embeddings in Bert decoder

In Rothe et al.'s "Leveraging Pre-trained Checkpoints for Sequence
Generation Tasks", Bert2Bert is initialized with pre-trained weights for
the encoder, and only pre-trained embeddings for the decoder. The
current version of the code completely randomizes the weights of the
decoder.

We write a custom function to initialize the weights of the decoder: we
first initialize the decoder with the pretrained weights, then randomize
everything but the embeddings.
commit f8e98d6779
parent 1e68c28670
Author: Rémi Louf
Date:   2019-10-11 16:48:11 +02:00
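Before the diff itself, here is a minimal standalone sketch of the mechanism the commit relies on: `Module.apply` visits every submodule recursively, so a callback that only matches `nn.Linear` and LayerNorm re-initializes those layers while leaving `nn.Embedding` (and thus the pretrained embeddings) untouched. The function name and the toy decoder below are illustrative, not part of the commit:

import torch
from torch import nn

def randomize_non_embedding_weights(module, std=0.02):
    # Re-initialize Linear and LayerNorm parameters. nn.Embedding is
    # deliberately not matched, so pretrained embeddings are preserved.
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=std)
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)

# Toy "decoder": an embedding plus a couple of transformer-like layers.
decoder = nn.ModuleDict({
    "embeddings": nn.Embedding(30522, 768),
    "dense": nn.Linear(768, 768),
    "norm": nn.LayerNorm(768),
})
embeddings_before = decoder["embeddings"].weight.clone()
decoder.apply(randomize_non_embedding_weights)  # .apply recurses over all submodules
assert torch.equal(embeddings_before, decoder["embeddings"].weight)  # embeddings untouched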


@@ -1348,15 +1348,14 @@ class Bert2Rnd(BertPreTrainedModel):
         self.encoder = BertModel(config)
         self.decoder = BertDecoderModel(config)
         self.init_weights()

     @classmethod
     def from_pretrained(cls, pretrained_model_or_path, *model_args, **model_kwargs):
-        """ Load the pretrained weights in the encoder.
-        Since the decoder needs to be initialized with random weights, and the encoder with
-        pretrained weights we need to override the `from_pretrained` method of the base `PreTrainedModel`
-        class.
+        """ The encoder of `Bert2Rnd` is initialized with pretrained weights; the
+        weights of the decoder are initialized at random except the embeddings,
+        which are initialized with the pretrained embeddings. We thus need to override
+        the base class' `from_pretrained` method.
         """
         # Load the configuration
@@ -1374,10 +1373,26 @@ class Bert2Rnd(BertPreTrainedModel):
         )
         model = cls(config)

-        # The encoder is loaded with pretrained weights
+        # We load the encoder with pretrained weights
         pretrained_encoder = BertModel.from_pretrained(pretrained_model_or_path, *model_args, **model_kwargs)
         model.encoder = pretrained_encoder

+        # We load the decoder with pretrained weights and then randomize all weights but embeddings-related ones.
+        def randomize_decoder_weights(module):
+            if isinstance(module, nn.Linear):
+                # Slightly different from the TF version which uses truncated_normal for initialization
+                # cf https://github.com/pytorch/pytorch/pull/5617
+                module.weight.data.normal_(mean=0.0, std=config.initializer_range)
+            elif isinstance(module, BertLayerNorm):
+                module.bias.data.zero_()
+                module.weight.data.fill_(1.0)
+            if isinstance(module, nn.Linear) and module.bias is not None:
+                module.bias.data.zero_()
+
+        pretrained_decoder = BertDecoderModel.from_pretrained(pretrained_model_or_path, *model_args, **model_kwargs)
+        pretrained_decoder.apply(randomize_decoder_weights)
+        model.decoder = pretrained_decoder

         return model

     def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
@@ -1386,11 +1401,9 @@ class Bert2Rnd(BertPreTrainedModel):
                                        token_type_ids=token_type_ids,
                                        position_ids=position_ids,
                                        head_mask=head_mask)
-        encoder_output = encoder_outputs[0]
         decoder_outputs = self.decoder(input_ids,
-                                       encoder_output,
+                                       encoder_outputs[0],
                                        token_type_ids=token_type_ids,
                                        position_ids=position_ids,
                                        head_mask=head_mask)
-        return decoder_outputs[0]
+        return decoder_outputs
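For completeness, a hedged usage sketch of the class this diff modifies. It assumes `Bert2Rnd` is importable from the patched module and that a standard checkpoint name like "bert-base-uncased" resolves; the token ids are purely illustrative:

import torch

# Assumes Bert2Rnd from the diff above is in scope.
model = Bert2Rnd.from_pretrained("bert-base-uncased")  # pretrained encoder; decoder random except embeddings
model.eval()

input_ids = torch.tensor([[101, 7592, 2088, 102]])  # "[CLS] hello world [SEP]", illustrative ids
with torch.no_grad():
    decoder_outputs = model(input_ids)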