diff --git a/examples/run_squad2.py b/examples/run_squad2.py index fd35beef1e2..558b24764e8 100644 --- a/examples/run_squad2.py +++ b/examples/run_squad2.py @@ -1010,7 +1010,8 @@ def main(): # Save a trained model model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self output_model_file = os.path.join(args.output_dir, "pytorch_model.bin") - torch.save(model_to_save.state_dict(), output_model_file) + if args.do_train: + torch.save(model_to_save.state_dict(), output_model_file) # Load a trained model that you have fine-tuned model_state_dict = torch.load(output_model_file)