Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-01 02:31:11 +06:00)
Load the pretrained weights for the encoder-decoder
We currently save the pretrained weights of the encoder and decoder in two separate directories, `encoder` and `decoder`. However, for the `from_pretrained` function to work with automodels, the type of model needs to appear in the path to the weights. The paths to the encoder/decoder weights are built by the `PreTrainedEncoderDecoder` class in its `save_pretrained` function. Since there is no easy way to infer the type of model that was initialized for the encoder and decoder, we add a `model_type` parameter to that function. This is not an ideal solution: it is error prone, and the model type should eventually be carried by the model classes themselves. It is a temporary fix that should be changed before merging. A minimal round-trip sketch is shown below.
parent 07f4cd73f6
commit 1c71ecc880
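For orientation, here is a minimal round-trip sketch of the behaviour introduced here. The `{model_type}_encoder`/`{model_type}_decoder` layout and the `model_type="bert"` default come from the diff below; the import path, the `bert-base-uncased` starting checkpoints, and the output directory name are assumptions for illustration only.

import os

from transformers import PreTrainedEncoderDecoder

save_directory = "./seq2seq_checkpoint"  # hypothetical output directory

# Initialize a BERT-to-BERT encoder-decoder, then save it with the new
# explicit model_type argument added in this commit.
model = PreTrainedEncoderDecoder.from_pretrained("bert-base-uncased", "bert-base-uncased")
model.save_pretrained(save_directory, model_type="bert")
# Weights now live in ./seq2seq_checkpoint/bert_encoder/ and ./seq2seq_checkpoint/bert_decoder/

# Reloading points at the two type-prefixed directories, exactly as the
# example script below does for its evaluation checkpoints.
reloaded = PreTrainedEncoderDecoder.from_pretrained(
    os.path.join(save_directory, "bert_encoder"),
    os.path.join(save_directory, "bert_decoder"),
)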
@@ -328,6 +328,22 @@ def evaluate(args, model, tokenizer, prefix=""):
     return result
 
 
+def save_model_checkpoints(args, model, tokenizer):
+    if not os.path.exists(args.output_dir):
+        os.makedirs(args.output_dir)
+
+    logger.info("Saving model checkpoint to %s", args.output_dir)
+
+    # Save a trained model, configuration and tokenizer using `save_pretrained()`.
+    # They can then be reloaded using `from_pretrained()`
+    model_to_save = (
+        model.module if hasattr(model, "module") else model
+    ) # Take care of distributed/parallel training
+    model_to_save.save_pretrained(args.output_dir, model_type='bert')
+    tokenizer.save_pretrained(args.output_dir)
+    torch.save(args, os.path.join(args.output_dir, "training_arguments.bin"))
+
+
 def main():
     parser = argparse.ArgumentParser()
 
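The `model.module if hasattr(model, "module") else model` line above is the usual unwrap for models wrapped in DataParallel/DistributedDataParallel, which expose the underlying network as `.module`. A self-contained sketch of that idiom with a plain module (the wrapper and layer here are placeholders, not part of the commit):

import torch.nn as nn

inner = nn.Linear(4, 2)            # stand-in for the encoder-decoder model
wrapped = nn.DataParallel(inner)   # training may wrap the model like this

# Unwrap before saving so we call the save methods on the real model, not the wrapper.
model_to_save = wrapped.module if hasattr(wrapped, "module") else wrapped
assert model_to_save is inner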
@@ -454,36 +470,30 @@ def main():
     # Train the model
     model.to(args.device)
     if args.do_train:
-        global_step, tr_loss = train(args, model, tokenizer)
+        try:
+            global_step, tr_loss = train(args, model, tokenizer)
+        except KeyboardInterrupt:
+            response = input("You interrupted the training. Do you want to save the model checkpoints? [Y/n]")
+            if response.lower() in ["", "y", "yes"]:
+                save_model_checkpoints(args, model, tokenizer)
+            sys.exit(0)
 
         logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
-    if not os.path.exists(args.output_dir):
-        os.makedirs(args.output_dir)
-
-    logger.info("Saving model checkpoint to %s", args.output_dir)
-
-    # Save a trained model, configuration and tokenizer using `save_pretrained()`.
-    # They can then be reloaded using `from_pretrained()`
-    model_to_save = (
-        model.module if hasattr(model, "module") else model
-    ) # Take care of distributed/parallel training
-    model_to_save.save_pretrained(args.output_dir)
-    tokenizer.save_pretrained(args.output_dir)
-    torch.save(args, os.path.join(args.output_dir, "training_arguments.bin"))
+        save_model_checkpoints(args, model, tokenizer)
 
     # Evaluate the model
     results = {}
     if args.do_evaluate:
-        checkpoints = []
+        checkpoints = [args.output_dir]
         logger.info("Evaluate the following checkpoints: %s", checkpoints)
         for checkpoint in checkpoints:
-            encoder_checkpoint = os.path.join(checkpoint, "encoder")
-            decoder_checkpoint = os.path.join(checkpoint, "decoder")
+            encoder_checkpoint = os.path.join(checkpoint, "bert_encoder")
+            decoder_checkpoint = os.path.join(checkpoint, "bert_decoder")
             model = PreTrainedEncoderDecoder.from_pretrained(
                 encoder_checkpoint, decoder_checkpoint
             )
             model.to(args.device)
-            results = "placeholder"
+            print("model loaded")
 
     return results
 
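The try/except block above is a generic interrupt-and-save pattern; a self-contained sketch follows, with `train_one_step` and `save_checkpoint` as placeholders standing in for the script's `train()` and `save_model_checkpoints()`:

import sys

def train_one_step(step):
    pass  # placeholder for real training work

def save_checkpoint():
    print("checkpoint saved")  # placeholder for save_model_checkpoints(args, model, tokenizer)

try:
    for step in range(1_000_000):
        train_one_step(step)
except KeyboardInterrupt:
    # Offer to persist the partially trained model before exiting cleanly.
    response = input("You interrupted the training. Do you want to save the model checkpoints? [Y/n] ")
    if response.lower() in ["", "y", "yes"]:
        save_checkpoint()
    sys.exit(0)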
@@ -117,8 +117,7 @@ class PreTrainedEncoderDecoder(nn.Module):
         kwargs_common = {
             argument: value
             for argument, value in kwargs.items()
-            if not argument.startswith("encoder_")
-            and not argument.startswith("decoder_")
+            if not argument.startswith("encoder_") and not argument.startswith("decoder_")
         }
         kwargs_decoder = kwargs_common.copy()
         kwargs_encoder = kwargs_common.copy()
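The dict comprehension touched by this hunk is half of the class's argument-routing idiom: un-prefixed kwargs are shared, while `encoder_*`/`decoder_*` kwargs go to the matching sub-model with the prefix stripped. A standalone sketch of that idiom, written here from the surrounding code as an assumption rather than copied from it:

def split_encoder_decoder_kwargs(**kwargs):
    # Shared kwargs: anything not explicitly addressed to one sub-model.
    kwargs_common = {
        argument: value
        for argument, value in kwargs.items()
        if not argument.startswith("encoder_") and not argument.startswith("decoder_")
    }
    kwargs_encoder = kwargs_common.copy()
    kwargs_decoder = kwargs_common.copy()
    # Prefixed kwargs override the shared ones, with the prefix removed.
    kwargs_encoder.update(
        {k[len("encoder_"):]: v for k, v in kwargs.items() if k.startswith("encoder_")}
    )
    kwargs_decoder.update(
        {k[len("decoder_"):]: v for k, v in kwargs.items() if k.startswith("decoder_")}
    )
    return kwargs_encoder, kwargs_decoder

enc, dec = split_encoder_decoder_kwargs(attention_mask=None, decoder_lm_labels=None)
# enc == {"attention_mask": None}
# dec == {"attention_mask": None, "lm_labels": None}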
@@ -158,14 +157,27 @@ class PreTrainedEncoderDecoder(nn.Module):
 
         return model
 
-    def save_pretrained(self, save_directory):
-        """ Save a Seq2Seq model and its configuration file in a format such
+    def save_pretrained(self, save_directory, model_type="bert"):
+        """ Save an EncoderDecoder model and its configuration file in a format such
         that it can be loaded using `:func:`~transformers.PreTrainedEncoderDecoder.from_pretrained`
 
         We save the encoder' and decoder's parameters in two separate directories.
+
+        If we want the weight loader to function we need to preprend the model
+        type to the directories' names. As far as I know there is no simple way
+        to infer the type of the model (except maybe by parsing the class'
+        names, which is not very future-proof). For now, we ask the user to
+        specify the model type explicitly when saving the weights.
         """
-        self.encoder.save_pretrained(os.path.join(save_directory, "encoder"))
-        self.decoder.save_pretrained(os.path.join(save_directory, "decoder"))
+        encoder_path = os.path.join(save_directory, "{}_encoder".format(model_type))
+        if not os.path.exists(encoder_path):
+            os.makedirs(encoder_path)
+        self.encoder.save_pretrained(encoder_path)
+
+        decoder_path = os.path.join(save_directory, "{}_decoder".format(model_type))
+        if not os.path.exists(decoder_path):
+            os.makedirs(decoder_path)
+        self.decoder.save_pretrained(decoder_path)
 
     def forward(self, encoder_input_ids, decoder_input_ids, **kwargs):
         """ The forward pass on a seq2eq depends what we are performing:
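The new docstring mentions inferring the model type "by parsing the class' names" and rejects it as not future-proof. Purely as an illustration of that rejected alternative (not something this commit implements), such an inference could look like this:

def infer_model_type(model):
    # e.g. a hypothetical BertModel encoder -> "bert"; breaks as soon as a
    # class name stops following the <Type>Model naming convention.
    return type(model).__name__.replace("Model", "").lower()

class BertModel:  # dummy stand-in for the real encoder class
    pass

assert infer_model_type(BertModel()) == "bert"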
@@ -193,8 +205,7 @@ class PreTrainedEncoderDecoder(nn.Module):
         kwargs_common = {
             argument: value
             for argument, value in kwargs.items()
-            if not argument.startswith("encoder_")
-            and not argument.startswith("decoder_")
+            if not argument.startswith("encoder_") and not argument.startswith("decoder_")
         }
         kwargs_decoder = kwargs_common.copy()
         kwargs_encoder = kwargs_common.copy()
@@ -217,9 +228,7 @@ class PreTrainedEncoderDecoder(nn.Module):
         encoder_hidden_states = kwargs_encoder.pop("hidden_states", None)
         if encoder_hidden_states is None:
             encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder)
-            encoder_hidden_states = encoder_outputs[
-                0
-            ] # output the last layer hidden state
+            encoder_hidden_states = encoder_outputs[0] # output the last layer hidden state
         else:
             encoder_outputs = ()
 
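This last hunk only collapses a line-wrapped indexing expression: element 0 of the encoder's output tuple is its last-layer hidden state. A shape-only illustration with dummy tensors (the sizes are assumptions):

import torch

batch_size, sequence_length, hidden_size = 2, 7, 768
encoder_outputs = (torch.zeros(batch_size, sequence_length, hidden_size),)
encoder_hidden_states = encoder_outputs[0]  # output the last layer hidden state
assert encoder_hidden_states.shape == (batch_size, sequence_length, hidden_size)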