diff --git a/examples/run_classifier.py b/examples/run_classifier.py
index f18c5489bae..7c4eb7da47f 100644
--- a/examples/run_classifier.py
+++ b/examples/run_classifier.py
@@ -546,6 +546,15 @@ def main():
                     optimizer.zero_grad()
                     global_step += 1
 
+    # Save a trained model
+    model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
+    output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
+    torch.save(model_to_save.state_dict(), output_model_file)
+
+    # Load a trained model that you have fine-tuned
+    model_state_dict = torch.load(output_model_file)
+    model = BertForSequenceClassification.from_pretrained(args.bert_model, state_dict=model_state_dict)
+
     if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
         eval_examples = processor.get_dev_examples(args.data_dir)
         eval_features = convert_examples_to_features(
@@ -593,10 +602,6 @@ def main():
                   'global_step': global_step,
                   'loss': tr_loss/nb_tr_steps}
 
-        model_to_save = model.module if hasattr(model, 'module') else model
-        raise NotImplementedError # TODO add save of the configuration file and vocabulary file also ?
-        output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
-        torch.save(model_to_save, output_model_file)
         output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
         with open(output_eval_file, "w") as writer:
             logger.info("***** Eval results *****")
diff --git a/examples/run_squad.py b/examples/run_squad.py
index b0668b38d8c..81956ad394c 100644
--- a/examples/run_squad.py
+++ b/examples/run_squad.py
@@ -911,6 +911,15 @@ def main():
                     optimizer.zero_grad()
                     global_step += 1
 
+    # Save a trained model
+    model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
+    output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
+    torch.save(model_to_save.state_dict(), output_model_file)
+
+    # Load a trained model that you have fine-tuned
+    model_state_dict = torch.load(output_model_file)
+    model = BertForQuestionAnswering.from_pretrained(args.bert_model, state_dict=model_state_dict)
+
     if args.do_predict and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
         eval_examples = read_squad_examples(
             input_file=args.predict_file, is_training=False)
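
The pattern both hunks introduce (unwrap a DataParallel/DistributedDataParallel container, serialize only the state_dict, then rebuild the model through from_pretrained with the state_dict keyword) can be exercised outside the training scripts. Below is a minimal sketch assuming the pytorch_pretrained_bert API of this era; the model name "bert-base-uncased" and the output directory are illustrative placeholders, not values taken from the diff.

    import os
    import torch
    from pytorch_pretrained_bert import BertForSequenceClassification

    output_dir = "/tmp/bert_output"  # placeholder; stands in for args.output_dir
    os.makedirs(output_dir, exist_ok=True)

    # Stand-in for a model that has just finished fine-tuning.
    model = BertForSequenceClassification.from_pretrained("bert-base-uncased")

    # Unwrap DataParallel/DistributedDataParallel so only the bare model is saved.
    model_to_save = model.module if hasattr(model, "module") else model
    output_model_file = os.path.join(output_dir, "pytorch_model.bin")
    torch.save(model_to_save.state_dict(), output_model_file)

    # Rebuild the architecture from the pre-trained checkpoint, then overlay the
    # fine-tuned weights via the state_dict keyword argument.
    model_state_dict = torch.load(output_model_file)
    model = BertForSequenceClassification.from_pretrained(
        "bert-base-uncased", state_dict=model_state_dict)

Saving the state_dict, rather than pickling the whole module as the old torch.save(model_to_save, output_model_file) did before the second hunk deleted it, keeps the checkpoint independent of the class definition and its import path, and removes the need for the NotImplementedError placeholder.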