mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 18:51:14 +06:00
add saving and loading model in examples
This commit is contained in:
parent
93f335ef86
commit
d3fcec1a3e
@ -546,6 +546,15 @@ def main():
|
|||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
global_step += 1
|
global_step += 1
|
||||||
|
|
||||||
|
# Save a trained model
|
||||||
|
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self
|
||||||
|
output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
|
||||||
|
torch.save(model_to_save.state_dict(), output_model_file)
|
||||||
|
|
||||||
|
# Load a trained model that you have fine-tuned
|
||||||
|
model_state_dict = torch.load(output_model_file)
|
||||||
|
model = BertForSequenceClassification.from_pretrained(args.bert_model, state_dict=model_state_dict)
|
||||||
|
|
||||||
if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||||
eval_examples = processor.get_dev_examples(args.data_dir)
|
eval_examples = processor.get_dev_examples(args.data_dir)
|
||||||
eval_features = convert_examples_to_features(
|
eval_features = convert_examples_to_features(
|
||||||
@ -593,10 +602,6 @@ def main():
|
|||||||
'global_step': global_step,
|
'global_step': global_step,
|
||||||
'loss': tr_loss/nb_tr_steps}
|
'loss': tr_loss/nb_tr_steps}
|
||||||
|
|
||||||
model_to_save = model.module if hasattr(model, 'module') else model
|
|
||||||
raise NotImplementedError # TODO add save of the configuration file and vocabulary file also ?
|
|
||||||
output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
|
|
||||||
torch.save(model_to_save, output_model_file)
|
|
||||||
output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
|
output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
|
||||||
with open(output_eval_file, "w") as writer:
|
with open(output_eval_file, "w") as writer:
|
||||||
logger.info("***** Eval results *****")
|
logger.info("***** Eval results *****")
|
||||||
|
@ -911,6 +911,15 @@ def main():
|
|||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
global_step += 1
|
global_step += 1
|
||||||
|
|
||||||
|
# Save a trained model
|
||||||
|
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self
|
||||||
|
output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
|
||||||
|
torch.save(model_to_save.state_dict(), output_model_file)
|
||||||
|
|
||||||
|
# Load a trained model that you have fine-tuned
|
||||||
|
model_state_dict = torch.load(output_model_file)
|
||||||
|
model = BertForQuestionAnswering.from_pretrained(args.bert_model, state_dict=model_state_dict)
|
||||||
|
|
||||||
if args.do_predict and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
if args.do_predict and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||||
eval_examples = read_squad_examples(
|
eval_examples = read_squad_examples(
|
||||||
input_file=args.predict_file, is_training=False)
|
input_file=args.predict_file, is_training=False)
|
||||||
|
Loading…
Reference in New Issue
Block a user