Mirror of https://github.com/huggingface/transformers.git

clarify and unify model saving logic in examples

commit eebc8abbe2 (parent 81c7e3ec9f)
README.md

@@ -779,7 +779,8 @@ python run_classifier.py \
   --train_batch_size 32 \
   --learning_rate 2e-5 \
   --num_train_epochs 3.0 \
-  --output_dir /tmp/mrpc_output/
+  --output_dir /tmp/mrpc_output/ \
+  --fp16
 ```

 #### SQuAD
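The new --fp16 flag switches the example scripts to half-precision weights (see the `if args.fp16: model.half()` context in the run_classifier.py hunk further down). A toy illustration of what the flag controls, with a plain Linear layer standing in for BERT:

```python
import torch

model = torch.nn.Linear(4, 2)   # stand-in for the BERT model
fp16 = True                     # what the new --fp16 flag sets
if fp16:
    model.half()                # cast parameters to torch.float16
print(next(model.parameters()).dtype)  # -> torch.float16
```

The --loss_scale option visible in a later hunk exists to compensate for fp16's reduced dynamic range during backpropagation.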
examples/run_classifier.py

@@ -23,7 +23,6 @@ import logging
 import os
 import random
-import sys
 from io import open

 import numpy as np
 import torch
@@ -33,7 +32,7 @@ from torch.utils.data.distributed import DistributedSampler
 from tqdm import tqdm, trange

 from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
-from pytorch_pretrained_bert.modeling import BertForSequenceClassification
+from pytorch_pretrained_bert.modeling import BertForSequenceClassification, BertConfig, WEIGHTS_NAME, CONFIG_NAME
 from pytorch_pretrained_bert.tokenization import BertTokenizer
 from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear

@@ -92,7 +91,7 @@ class DataProcessor(object):
     @classmethod
     def _read_tsv(cls, input_file, quotechar=None):
         """Reads a tab separated value file."""
-        with open(input_file, "rb") as f:
+        with open(input_file, "r") as f:
             reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
             lines = []
             for line in reader:
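The "rb" to "r" switch (and the dropped `import sys` above) track Python 3's csv module, which iterates over text-mode files; under Python 2, binary mode plus a `sys.version_info` re-encoding branch was needed instead. A minimal standalone sketch of the resulting reader (a hypothetical helper, not part of the commit):

```python
import csv

def read_tsv(input_file, quotechar=None):
    """Read a TSV file; csv.reader in Python 3 expects str lines, so the
    file must be opened in text mode ("r"), not binary ("rb")."""
    with open(input_file, "r") as f:
        return list(csv.reader(f, delimiter="\t", quotechar=quotechar))
```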
@@ -324,6 +323,10 @@ def main():
                         help="The output directory where the model predictions and checkpoints will be written.")

     ## Other parameters
+    parser.add_argument("--cache_dir",
+                        default="",
+                        type=str,
+                        help="Where do you want to store the pre-trained models downloaded from s3")
     parser.add_argument("--max_seq_length",
                         default=128,
                         type=int,
@@ -383,9 +386,17 @@ def main():
                         help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                              "0 (default value): dynamic loss scaling.\n"
                              "Positive power of 2: static loss scaling value.\n")
+    parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
+    parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
     args = parser.parse_args()

+    if args.server_ip and args.server_port:
+        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
+        import ptvsd
+        print("Waiting for debugger attach")
+        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
+        ptvsd.wait_for_attach()
+
     processors = {
         "cola": ColaProcessor,
         "mnli": MnliProcessor,
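For illustration, one plausible way to use the new debugging flags (the IP and port below are placeholders and must match the host/port configured in the attaching debugger; the script blocks at `ptvsd.wait_for_attach()` until the debugger connects):

```
python run_classifier.py \
  --server_ip 0.0.0.0 \
  --server_port 5678 \
  ...   # remaining task arguments as in the README example
```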
@@ -451,8 +462,9 @@ def main():
         num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()

     # Prepare model
+    cache_dir = args.cache_dir if args.cache_dir else os.path.join(PYTORCH_PRETRAINED_BERT_CACHE, 'distributed_{}'.format(args.local_rank))
     model = BertForSequenceClassification.from_pretrained(args.bert_model,
-              cache_dir=os.path.join(PYTORCH_PRETRAINED_BERT_CACHE, 'distributed_{}'.format(args.local_rank)),
+              cache_dir=cache_dir,
               num_labels = num_labels)
     if args.fp16:
         model.half()
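When --cache_dir is empty, the default keeps one cache subdirectory per distributed rank, so concurrent processes do not clobber each other's downloads. A toy illustration of the resolved paths (the real PYTORCH_PRETRAINED_BERT_CACHE default is assumed here to live under the home directory):

```python
import os

cache_root = os.path.join(os.path.expanduser("~"), ".pytorch_pretrained_bert")  # assumed default
for local_rank in (-1, 0, 1):   # -1 means no distributed training
    print(os.path.join(cache_root, 'distributed_{}'.format(local_rank)))
# e.g. /home/user/.pytorch_pretrained_bert/distributed_-1, .../distributed_0, .../distributed_1
```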
@@ -549,15 +561,21 @@ def main():
                 optimizer.zero_grad()
                 global_step += 1

-    # Save a trained model
-    model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
-    output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
-    torch.save(model_to_save.state_dict(), output_model_file)
-
-    # Load a trained model that you have fine-tuned
-    model_state_dict = torch.load(output_model_file)
-    model = BertForSequenceClassification.from_pretrained(args.bert_model, state_dict=model_state_dict, num_labels=num_labels)
+    if args.do_train:
+        # Save a trained model and the associated configuration
+        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
+        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
+        torch.save(model_to_save.state_dict(), output_model_file)
+        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
+        with open(output_config_file, 'w') as f:
+            f.write(model_to_save.config.to_json_string())
+
+        # Load a trained model and config that you have fine-tuned
+        config = BertConfig(output_config_file)
+        model = BertForSequenceClassification(config, num_labels=num_labels)
+        model.load_state_dict(torch.load(output_model_file))
+    else:
+        model = BertForSequenceClassification.from_pretrained(args.bert_model, num_labels=num_labels)
     model.to(device)

     if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
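Taken together, the unified pattern in each example script is: after training, write the weights under WEIGHTS_NAME and the config JSON under CONFIG_NAME, then rebuild the model from those two files instead of re-downloading args.bert_model. A condensed sketch using hypothetical helpers `save_model`/`load_model` (not in the commit; names as imported from pytorch_pretrained_bert.modeling, `num_labels` as in run_classifier.py):

```python
import os
import torch
from pytorch_pretrained_bert.modeling import (BertConfig, BertForSequenceClassification,
                                              WEIGHTS_NAME, CONFIG_NAME)

def save_model(model, output_dir):
    # Unwrap DataParallel/DistributedDataParallel before saving.
    model_to_save = model.module if hasattr(model, 'module') else model
    torch.save(model_to_save.state_dict(), os.path.join(output_dir, WEIGHTS_NAME))
    with open(os.path.join(output_dir, CONFIG_NAME), 'w') as f:
        f.write(model_to_save.config.to_json_string())

def load_model(output_dir, num_labels):
    # Rebuild the architecture from the saved config, then load the fine-tuned weights.
    config = BertConfig(os.path.join(output_dir, CONFIG_NAME))
    model = BertForSequenceClassification(config, num_labels=num_labels)
    model.load_state_dict(torch.load(os.path.join(output_dir, WEIGHTS_NAME)))
    return model
```

Writing CONFIG_NAME alongside WEIGHTS_NAME is what lets the fine-tuned model be reconstructed later without access to the original pretrained archive.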
examples/run_squad.py

@@ -35,7 +35,7 @@ from torch.utils.data.distributed import DistributedSampler
 from tqdm import tqdm, trange

 from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
-from pytorch_pretrained_bert.modeling import BertForQuestionAnswering
+from pytorch_pretrained_bert.modeling import BertForQuestionAnswering, BertConfig, WEIGHTS_NAME, CONFIG_NAME
 from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
 from pytorch_pretrained_bert.tokenization import (BasicTokenizer,
                                                   BertTokenizer,
@@ -1001,14 +1001,19 @@ def main():
                 optimizer.zero_grad()
                 global_step += 1

-    # Save a trained model
-    model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
-    output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
-    torch.save(model_to_save.state_dict(), output_model_file)
-
-    # Load a trained model that you have fine-tuned
-    model_state_dict = torch.load(output_model_file)
-    model = BertForQuestionAnswering.from_pretrained(args.bert_model, state_dict=model_state_dict)
+    if args.do_train:
+        # Save a trained model and the associated configuration
+        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
+        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
+        torch.save(model_to_save.state_dict(), output_model_file)
+        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
+        with open(output_config_file, 'w') as f:
+            f.write(model_to_save.config.to_json_string())
+
+        # Load a trained model and config that you have fine-tuned
+        config = BertConfig(output_config_file)
+        model = BertForQuestionAnswering(config)
+        model.load_state_dict(torch.load(output_model_file))
+    else:
+        model = BertForQuestionAnswering.from_pretrained(args.bert_model)

examples/run_swag.py

@@ -469,18 +469,25 @@ def main():
                 optimizer.zero_grad()
                 global_step += 1

-    # Save a trained model
-    model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
-    output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
-    torch.save(model_to_save.state_dict(), output_model_file)
-
-    # Load a trained model that you have fine-tuned
-    model_state_dict = torch.load(output_model_file)
-    model = BertForMultipleChoice.from_pretrained(args.bert_model,
-        state_dict=model_state_dict,
-        num_choices=4)
+    if args.do_train:
+        # Save a trained model and the associated configuration
+        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
+        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
+        torch.save(model_to_save.state_dict(), output_model_file)
+        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
+        with open(output_config_file, 'w') as f:
+            f.write(model_to_save.config.to_json_string())
+
+        # Load a trained model and config that you have fine-tuned
+        config = BertConfig(output_config_file)
+        model = BertForMultipleChoice(config, num_choices=4)
+        model.load_state_dict(torch.load(output_model_file))
+    else:
+        model = BertForMultipleChoice.from_pretrained(args.bert_model, num_choices=4)
     model.to(device)


     if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
         eval_examples = read_swag_examples(os.path.join(args.data_dir, 'val.csv'), is_training = True)
         eval_features = convert_examples_to_features(
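A practical note, resting on an assumption about the library rather than on this diff: since WEIGHTS_NAME and CONFIG_NAME are pytorch_pretrained_bert's standard file names, a directory populated by the saving code above should also be loadable wherever a pretrained model path is accepted, e.g.:

```python
from pytorch_pretrained_bert.modeling import BertForSequenceClassification

# Assumption: from_pretrained() resolves a local directory containing
# WEIGHTS_NAME/CONFIG_NAME the same way it resolves a named archive.
model = BertForSequenceClassification.from_pretrained("/tmp/mrpc_output/", num_labels=2)  # MRPC has 2 labels
```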