commit eebc8abbe2
parent 81c7e3ec9f
Author: thomwolf
Date:   2019-02-11 14:04:19 +01:00

    clarify and unify model saving logic in examples

4 changed files with 59 additions and 28 deletions

--- a/README.md
+++ b/README.md

@@ -779,7 +779,8 @@ python run_classifier.py \
   --train_batch_size 32 \
   --learning_rate 2e-5 \
   --num_train_epochs 3.0 \
-  --output_dir /tmp/mrpc_output/
+  --output_dir /tmp/mrpc_output/ \
+  --fp16
 ```
 
 #### SQuAD

--- a/examples/run_classifier.py
+++ b/examples/run_classifier.py

@@ -23,7 +23,6 @@ import logging
 import os
 import random
 import sys
-from io import open
 
 import numpy as np
 import torch
@@ -33,7 +32,7 @@ from torch.utils.data.distributed import DistributedSampler
 from tqdm import tqdm, trange
 
 from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
-from pytorch_pretrained_bert.modeling import BertForSequenceClassification
+from pytorch_pretrained_bert.modeling import BertForSequenceClassification, BertConfig, WEIGHTS_NAME, CONFIG_NAME
 from pytorch_pretrained_bert.tokenization import BertTokenizer
 from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
@@ -92,7 +91,7 @@ class DataProcessor(object):
     @classmethod
     def _read_tsv(cls, input_file, quotechar=None):
         """Reads a tab separated value file."""
-        with open(input_file, "rb") as f:
+        with open(input_file, "r") as f:
             reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
             lines = []
             for line in reader:
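The switch from `"rb"` to `"r"` matters because Python 3's `csv` module operates on text, not bytes: a reader over a binary handle fails with `csv.Error: iterator should return strings, not bytes`. A minimal self-contained sketch of the corrected pattern (file name and contents are made up for illustration):

```python
import csv

# Write a tiny tab-separated file to read back (hypothetical data).
with open("sample.tsv", "w") as f:
    f.write("sentence\tlabel\n")
    f.write("hello world\t1\n")

# Text mode ("r") is required in Python 3: csv.reader iterates over
# str lines, and a binary handle would yield bytes and raise csv.Error.
with open("sample.tsv", "r") as f:
    reader = csv.reader(f, delimiter="\t")
    lines = list(reader)

print(lines)  # [['sentence', 'label'], ['hello world', '1']]
```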
@@ -324,6 +323,10 @@ def main():
                         help="The output directory where the model predictions and checkpoints will be written.")
 
     ## Other parameters
+    parser.add_argument("--cache_dir",
+                        default="",
+                        type=str,
+                        help="Where do you want to store the pre-trained models downloaded from s3")
     parser.add_argument("--max_seq_length",
                         default=128,
                         type=int,
@@ -383,9 +386,17 @@ def main():
                         help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                              "0 (default value): dynamic loss scaling.\n"
                              "Positive power of 2: static loss scaling value.\n")
+    parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
+    parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
     args = parser.parse_args()
 
+    if args.server_ip and args.server_port:
+        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
+        import ptvsd
+        print("Waiting for debugger attach")
+        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
+        ptvsd.wait_for_attach()
+
     processors = {
         "cola": ColaProcessor,
         "mnli": MnliProcessor,
@@ -451,8 +462,9 @@ def main():
         num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()
 
     # Prepare model
+    cache_dir = args.cache_dir if args.cache_dir else os.path.join(PYTORCH_PRETRAINED_BERT_CACHE, 'distributed_{}'.format(args.local_rank))
     model = BertForSequenceClassification.from_pretrained(args.bert_model,
-              cache_dir=os.path.join(PYTORCH_PRETRAINED_BERT_CACHE, 'distributed_{}'.format(args.local_rank)),
+              cache_dir=cache_dir,
               num_labels = num_labels)
     if args.fp16:
         model.half()
@@ -549,15 +561,21 @@ def main():
                     optimizer.zero_grad()
                     global_step += 1
 
-    # Save a trained model
-    model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
-    output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
     if args.do_train:
+        # Save a trained model and the associated configuration
+        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
+        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
         torch.save(model_to_save.state_dict(), output_model_file)
+        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
+        with open(output_config_file, 'w') as f:
+            f.write(model_to_save.config.to_json_string())
 
-    # Load a trained model that you have fine-tuned
-    model_state_dict = torch.load(output_model_file)
-    model = BertForSequenceClassification.from_pretrained(args.bert_model, state_dict=model_state_dict, num_labels=num_labels)
+        # Load a trained model and config that you have fine-tuned
+        config = BertConfig(output_config_file)
+        model = BertForSequenceClassification(config, num_labels=num_labels)
+        model.load_state_dict(torch.load(output_model_file))
+    else:
+        model = BertForSequenceClassification.from_pretrained(args.bert_model, num_labels=num_labels)
     model.to(device)
 
     if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
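All three example scripts now follow the same save/load discipline: unwrap any `DataParallel`-style wrapper before saving, write the weights and the configuration side by side, then rebuild the model from that configuration and reload the weights. A self-contained sketch of the pattern with a toy module (the class and its config format are invented for illustration; `WEIGHTS_NAME` mirrors the `pytorch_model.bin` name the scripts used before, and the config file name here is an assumption):

```python
import json
import os

import torch
import torch.nn as nn

WEIGHTS_NAME = "pytorch_model.bin"  # file name used by the examples
CONFIG_NAME = "config.json"         # assumed name for the saved config

class TinyClassifier(nn.Module):
    """Stand-in for BertForSequenceClassification, for illustration only."""
    def __init__(self, hidden_size, num_labels):
        super(TinyClassifier, self).__init__()
        self.config = {"hidden_size": hidden_size, "num_labels": num_labels}
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, x):
        return self.classifier(x)

output_dir = "/tmp/toy_output"
os.makedirs(output_dir, exist_ok=True)

# Pretend training wrapped the model for multi-GPU.
model = nn.DataParallel(TinyClassifier(hidden_size=8, num_labels=2))

# 1. Unwrap the wrapper so only the bare model's parameters are saved.
model_to_save = model.module if hasattr(model, "module") else model

# 2. Save weights and configuration side by side.
torch.save(model_to_save.state_dict(), os.path.join(output_dir, WEIGHTS_NAME))
with open(os.path.join(output_dir, CONFIG_NAME), "w") as f:
    json.dump(model_to_save.config, f)

# 3. Rebuild from the saved configuration and reload the weights.
with open(os.path.join(output_dir, CONFIG_NAME)) as f:
    config = json.load(f)
model = TinyClassifier(config["hidden_size"], config["num_labels"])
model.load_state_dict(torch.load(os.path.join(output_dir, WEIGHTS_NAME)))
```

Rebuilding from the saved config rather than from `args.bert_model` guarantees the evaluation model has exactly the architecture that was trained, independent of what the pretrained shortcut name resolves to.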

--- a/examples/run_squad.py
+++ b/examples/run_squad.py

@@ -35,7 +35,7 @@ from torch.utils.data.distributed import DistributedSampler
 from tqdm import tqdm, trange
 
 from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
-from pytorch_pretrained_bert.modeling import BertForQuestionAnswering
+from pytorch_pretrained_bert.modeling import BertForQuestionAnswering, BertConfig, WEIGHTS_NAME, CONFIG_NAME
 from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
 from pytorch_pretrained_bert.tokenization import (BasicTokenizer,
                                                   BertTokenizer,
@@ -1001,14 +1001,19 @@ def main():
                     optimizer.zero_grad()
                     global_step += 1
 
-    # Save a trained model
-    model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
-    output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
     if args.do_train:
+        # Save a trained model and the associated configuration
+        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
+        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
         torch.save(model_to_save.state_dict(), output_model_file)
-        # Load a trained model that you have fine-tuned
-        model_state_dict = torch.load(output_model_file)
-        model = BertForQuestionAnswering.from_pretrained(args.bert_model, state_dict=model_state_dict)
+        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
+        with open(output_config_file, 'w') as f:
+            f.write(model_to_save.config.to_json_string())
+
+        # Load a trained model and config that you have fine-tuned
+        config = BertConfig(output_config_file)
+        model = BertForQuestionAnswering(config)
+        model.load_state_dict(torch.load(output_model_file))
     else:
         model = BertForQuestionAnswering.from_pretrained(args.bert_model)
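Because the weights and the configuration now live together in `--output_dir`, a fine-tuned model can later be reloaded without touching `args.bert_model` at all. A minimal sketch of that reload, mirroring the calls in the hunk above and assuming a model was previously saved to a hypothetical `/tmp/squad_output`:

```python
import os

import torch
from pytorch_pretrained_bert.modeling import (BertConfig, BertForQuestionAnswering,
                                              CONFIG_NAME, WEIGHTS_NAME)

output_dir = "/tmp/squad_output"  # hypothetical --output_dir of a prior run

# Rebuild the architecture from the saved config, then load the weights,
# the same two-step reload the updated script performs after training.
config = BertConfig(os.path.join(output_dir, CONFIG_NAME))
model = BertForQuestionAnswering(config)
model.load_state_dict(torch.load(os.path.join(output_dir, WEIGHTS_NAME)))
model.eval()  # inference mode for prediction
```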

--- a/examples/run_swag.py
+++ b/examples/run_swag.py

@@ -469,18 +469,25 @@ def main():
                     optimizer.zero_grad()
                     global_step += 1
 
-    # Save a trained model
-    model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
-    output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
-    torch.save(model_to_save.state_dict(), output_model_file)
-
-    # Load a trained model that you have fine-tuned
-    model_state_dict = torch.load(output_model_file)
-    model = BertForMultipleChoice.from_pretrained(args.bert_model,
-                                                  state_dict=model_state_dict,
-                                                  num_choices=4)
+    if args.do_train:
+        # Save a trained model and the associated configuration
+        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
+        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
+        torch.save(model_to_save.state_dict(), output_model_file)
+        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
+        with open(output_config_file, 'w') as f:
+            f.write(model_to_save.config.to_json_string())
+
+        # Load a trained model and config that you have fine-tuned
+        config = BertConfig(output_config_file)
+        model = BertForMultipleChoice(config, num_choices=4)
+        model.load_state_dict(torch.load(output_model_file))
+    else:
+        model = BertForMultipleChoice.from_pretrained(args.bert_model, num_choices=4)
     model.to(device)
 
     if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
         eval_examples = read_swag_examples(os.path.join(args.data_dir, 'val.csv'), is_training = True)
         eval_features = convert_examples_to_features(