load the pretrained weights for encoder-decoder

We currently save the pretrained weights of the encoder and the decoder in
two separate directories, `encoder` and `decoder`. However, for the
`from_pretrained` function to work with the AutoModel classes, the model
type needs to appear in the path to the weights.
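
To make this concrete, here is a minimal loading sketch based on the directory layout this commit introduces (`bert_encoder`/`bert_decoder` inside the checkpoint directory); the `./output` path is illustrative, and I assume `PreTrainedEncoderDecoder` is importable from the package root:

```python
import os

from transformers import PreTrainedEncoderDecoder

# Illustrative checkpoint directory; the subdirectory names carry the model
# type so that the AutoModel-based loader can resolve the right architecture.
checkpoint = "./output"
encoder_checkpoint = os.path.join(checkpoint, "bert_encoder")
decoder_checkpoint = os.path.join(checkpoint, "bert_decoder")

model = PreTrainedEncoderDecoder.from_pretrained(
    encoder_checkpoint, decoder_checkpoint
)
```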

The paths to the encoder/decoder weights are built by the
`PreTrainedEncoderDecoder` class in its `save_pretrained` function. Since
there is no easy way to infer which type of model was initialized for the
encoder and the decoder, we add a `model_type` parameter to the function.
This is not an ideal solution as it is error-prone; the model type should
somehow be carried by the model classes themselves.
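
For reference, saving then looks like the following sketch (the model instantiation and the `./output` directory are illustrative; the new `model_type` argument is the one introduced in the diff below):

```python
from transformers import PreTrainedEncoderDecoder

# Illustrative: encoder and decoder both initialized from a BERT checkpoint.
model = PreTrainedEncoderDecoder.from_pretrained(
    "bert-base-uncased", "bert-base-uncased"
)

# The `model_type` argument prefixes the two subdirectories, producing
# ./output/bert_encoder and ./output/bert_decoder.
model.save_pretrained("./output", model_type="bert")
```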

This is a temporary fix that should be changed before merging.
Rémi Louf, 2019-10-31 10:16:08 +01:00 (committed by Julien Chaumond)
parent 07f4cd73f6
commit 1c71ecc880
2 changed files with 49 additions and 30 deletions

View File

@@ -328,6 +328,22 @@ def evaluate(args, model, tokenizer, prefix=""):
     return result
 
 
+def save_model_checkpoints(args, model, tokenizer):
+    if not os.path.exists(args.output_dir):
+        os.makedirs(args.output_dir)
+
+    logger.info("Saving model checkpoint to %s", args.output_dir)
+
+    # Save a trained model, configuration and tokenizer using `save_pretrained()`.
+    # They can then be reloaded using `from_pretrained()`
+    model_to_save = (
+        model.module if hasattr(model, "module") else model
+    )  # Take care of distributed/parallel training
+    model_to_save.save_pretrained(args.output_dir, model_type="bert")
+    tokenizer.save_pretrained(args.output_dir)
+    torch.save(args, os.path.join(args.output_dir, "training_arguments.bin"))
+
+
 def main():
     parser = argparse.ArgumentParser()
@@ -454,36 +470,30 @@ def main():
     # Train the model
     model.to(args.device)
     if args.do_train:
-        global_step, tr_loss = train(args, model, tokenizer)
+        try:
+            global_step, tr_loss = train(args, model, tokenizer)
+        except KeyboardInterrupt:
+            response = input("You interrupted the training. Do you want to save the model checkpoints? [Y/n]")
+            if response.lower() in ["", "y", "yes"]:
+                save_model_checkpoints(args, model, tokenizer)
+            sys.exit(0)
+
         logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
-
-        if not os.path.exists(args.output_dir):
-            os.makedirs(args.output_dir)
-        logger.info("Saving model checkpoint to %s", args.output_dir)
-
-        # Save a trained model, configuration and tokenizer using `save_pretrained()`.
-        # They can then be reloaded using `from_pretrained()`
-        model_to_save = (
-            model.module if hasattr(model, "module") else model
-        )  # Take care of distributed/parallel training
-        model_to_save.save_pretrained(args.output_dir)
-        tokenizer.save_pretrained(args.output_dir)
-        torch.save(args, os.path.join(args.output_dir, "training_arguments.bin"))
+        save_model_checkpoints(args, model, tokenizer)
 
     # Evaluate the model
     results = {}
     if args.do_evaluate:
-        checkpoints = []
+        checkpoints = [args.output_dir]
         logger.info("Evaluate the following checkpoints: %s", checkpoints)
         for checkpoint in checkpoints:
-            encoder_checkpoint = os.path.join(checkpoint, "encoder")
-            decoder_checkpoint = os.path.join(checkpoint, "decoder")
+            encoder_checkpoint = os.path.join(checkpoint, "bert_encoder")
+            decoder_checkpoint = os.path.join(checkpoint, "bert_decoder")
             model = PreTrainedEncoderDecoder.from_pretrained(
                 encoder_checkpoint, decoder_checkpoint
             )
             model.to(args.device)
-            results = "placeholder"
+            print("model loaded")
 
     return results

View File

@@ -117,8 +117,7 @@ class PreTrainedEncoderDecoder(nn.Module):
         kwargs_common = {
             argument: value
             for argument, value in kwargs.items()
-            if not argument.startswith("encoder_")
-            and not argument.startswith("decoder_")
+            if not argument.startswith("encoder_") and not argument.startswith("decoder_")
         }
         kwargs_decoder = kwargs_common.copy()
         kwargs_encoder = kwargs_common.copy()
@@ -158,14 +157,27 @@ class PreTrainedEncoderDecoder(nn.Module):
         return model
 
-    def save_pretrained(self, save_directory):
-        """ Save a Seq2Seq model and its configuration file in a format such
+    def save_pretrained(self, save_directory, model_type="bert"):
+        """ Save an EncoderDecoder model and its configuration file in a format such
         that it can be loaded using `:func:`~transformers.PreTrainedEncoderDecoder.from_pretrained`
 
         We save the encoder' and decoder's parameters in two separate directories.
+        If we want the weight loader to function we need to preprend the model
+        type to the directories' names. As far as I know there is no simple way
+        to infer the type of the model (except maybe by parsing the class'
+        names, which is not very future-proof). For now, we ask the user to
+        specify the model type explicitly when saving the weights.
         """
-        self.encoder.save_pretrained(os.path.join(save_directory, "encoder"))
-        self.decoder.save_pretrained(os.path.join(save_directory, "decoder"))
+        encoder_path = os.path.join(save_directory, "{}_encoder".format(model_type))
+        if not os.path.exists(encoder_path):
+            os.makedirs(encoder_path)
+        self.encoder.save_pretrained(encoder_path)
+
+        decoder_path = os.path.join(save_directory, "{}_decoder".format(model_type))
+        if not os.path.exists(decoder_path):
+            os.makedirs(decoder_path)
+        self.decoder.save_pretrained(decoder_path)
 
     def forward(self, encoder_input_ids, decoder_input_ids, **kwargs):
         """ The forward pass on a seq2eq depends what we are performing:
@@ -193,8 +205,7 @@ class PreTrainedEncoderDecoder(nn.Module):
         kwargs_common = {
             argument: value
             for argument, value in kwargs.items()
-            if not argument.startswith("encoder_")
-            and not argument.startswith("decoder_")
+            if not argument.startswith("encoder_") and not argument.startswith("decoder_")
         }
         kwargs_decoder = kwargs_common.copy()
         kwargs_encoder = kwargs_common.copy()
@@ -217,9 +228,7 @@ class PreTrainedEncoderDecoder(nn.Module):
         encoder_hidden_states = kwargs_encoder.pop("hidden_states", None)
         if encoder_hidden_states is None:
             encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder)
-            encoder_hidden_states = encoder_outputs[
-                0
-            ]  # output the last layer hidden state
+            encoder_hidden_states = encoder_outputs[0]  # output the last layer hidden state
         else:
             encoder_outputs = ()