mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
add saving and loading model in examples
This commit is contained in:
parent
93f335ef86
commit
d3fcec1a3e
@ -546,6 +546,15 @@ def main():
|
||||
optimizer.zero_grad()
|
||||
global_step += 1
|
||||
|
||||
# Save a trained model
|
||||
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self
|
||||
output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
|
||||
torch.save(model_to_save.state_dict(), output_model_file)
|
||||
|
||||
# Load a trained model that you have fine-tuned
|
||||
model_state_dict = torch.load(output_model_file)
|
||||
model = BertForSequenceClassification.from_pretrained(args.bert_model, state_dict=model_state_dict)
|
||||
|
||||
if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||
eval_examples = processor.get_dev_examples(args.data_dir)
|
||||
eval_features = convert_examples_to_features(
|
||||
@ -593,10 +602,6 @@ def main():
|
||||
'global_step': global_step,
|
||||
'loss': tr_loss/nb_tr_steps}
|
||||
|
||||
model_to_save = model.module if hasattr(model, 'module') else model
|
||||
raise NotImplementedError # TODO add save of the configuration file and vocabulary file also ?
|
||||
output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
|
||||
torch.save(model_to_save, output_model_file)
|
||||
output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
|
||||
with open(output_eval_file, "w") as writer:
|
||||
logger.info("***** Eval results *****")
|
||||
|
@ -911,6 +911,15 @@ def main():
|
||||
optimizer.zero_grad()
|
||||
global_step += 1
|
||||
|
||||
# Save a trained model
|
||||
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self
|
||||
output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
|
||||
torch.save(model_to_save.state_dict(), output_model_file)
|
||||
|
||||
# Load a trained model that you have fine-tuned
|
||||
model_state_dict = torch.load(output_model_file)
|
||||
model = BertForQuestionAnswering.from_pretrained(args.bert_model, state_dict=model_state_dict)
|
||||
|
||||
if args.do_predict and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||
eval_examples = read_squad_examples(
|
||||
input_file=args.predict_file, is_training=False)
|
||||
|
Loading…
Reference in New Issue
Block a user