Bug fix in examples;correct t_total for distributed training;run prediction for full dataset

This commit is contained in:
Li Li 2018-11-27 01:08:37 -08:00
parent 32167cdf4b
commit 0aaedcc02f
2 changed files with 42 additions and 22 deletions

View File

@ -33,6 +33,7 @@ from torch.utils.data.distributed import DistributedSampler
from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.modeling import BertForSequenceClassification
from pytorch_pretrained_bert.optimization import BertAdam
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S',
@ -155,8 +156,8 @@ class MnliProcessor(DataProcessor):
if i == 0:
continue
guid = "%s-%s" % (set_type, line[0])
text_a = line[8])
text_b = line[9])
text_a = line[8]
text_b = line[9]
label = line[-1]
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
@ -482,7 +483,7 @@ def main():
len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)
# Prepare model
model = BertForSequenceClassification.from_pretrained(args.bert_model, len(label_list),
model = BertForSequenceClassification.from_pretrained(args.bert_model,
cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank))
if args.fp16:
model.half()
@ -507,10 +508,13 @@ def main():
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.01},
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.0}
]
t_total = num_train_steps
if args.local_rank != -1:
t_total = t_total // torch.distributed.get_world_size()
optimizer = BertAdam(optimizer_grouped_parameters,
lr=args.learning_rate,
warmup=args.warmup_proportion,
t_total=num_train_steps)
t_total=t_total)
global_step = 0
if args.do_train:
@ -571,7 +575,7 @@ def main():
model.zero_grad()
global_step += 1
if args.do_eval:
if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
eval_examples = processor.get_dev_examples(args.data_dir)
eval_features = convert_examples_to_features(
eval_examples, label_list, args.max_seq_length, tokenizer)
@ -583,10 +587,8 @@ def main():
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
if args.local_rank == -1:
eval_sampler = SequentialSampler(eval_data)
else:
eval_sampler = DistributedSampler(eval_data)
# Run prediction for full data
eval_sampler = SequentialSampler(eval_data)
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)
model.eval()

View File

@ -25,6 +25,7 @@ import json
import math
import os
import random
import pickle
from tqdm import tqdm, trange
import numpy as np
@ -35,6 +36,7 @@ from torch.utils.data.distributed import DistributedSampler
from pytorch_pretrained_bert.tokenization import whitespace_tokenize, BasicTokenizer, BertTokenizer
from pytorch_pretrained_bert.modeling import BertForQuestionAnswering
from pytorch_pretrained_bert.optimization import BertAdam
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S',
@ -749,6 +751,10 @@ def main():
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.")
parser.add_argument("--do_lower_case",
default=True,
action='store_true',
help="Whether to lower case the input text. True for uncased models, False for cased models.")
parser.add_argument("--local_rank",
type=int,
default=-1,
@ -845,20 +851,34 @@ def main():
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.01},
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.0}
]
t_total = num_train_steps
if args.local_rank != -1:
t_total = t_total // torch.distributed.get_world_size()
optimizer = BertAdam(optimizer_grouped_parameters,
lr=args.learning_rate,
warmup=args.warmup_proportion,
t_total=num_train_steps)
t_total=t_total)
global_step = 0
if args.do_train:
train_features = convert_examples_to_features(
examples=train_examples,
tokenizer=tokenizer,
max_seq_length=args.max_seq_length,
doc_stride=args.doc_stride,
max_query_length=args.max_query_length,
is_training=True)
cached_train_features_file = args.train_file+'_{0}_{1}_{2}_{3}'.format(
args.bert_model, str(args.max_seq_length), str(args.doc_stride), str(args.max_query_length))
train_features = None
try:
with open(cached_train_features_file, "rb") as reader:
train_features = pickle.load(reader)
except:
train_features = convert_examples_to_features(
examples=train_examples,
tokenizer=tokenizer,
max_seq_length=args.max_seq_length,
doc_stride=args.doc_stride,
max_query_length=args.max_query_length,
is_training=True)
if args.local_rank == -1 or torch.distributed.get_rank() == 0:
logger.info(" Saving train features into cached file %s", cached_train_features_file)
with open(cached_train_features_file, "wb") as writer:
train_features = pickle.dump(train_features, writer)
logger.info("***** Running training *****")
logger.info(" Num orig examples = %d", len(train_examples))
logger.info(" Num split examples = %d", len(train_features))
@ -913,7 +933,7 @@ def main():
model.zero_grad()
global_step += 1
if args.do_predict:
if args.do_predict and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
eval_examples = read_squad_examples(
input_file=args.predict_file, is_training=False)
eval_features = convert_examples_to_features(
@ -934,10 +954,8 @@ def main():
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index)
if args.local_rank == -1:
eval_sampler = SequentialSampler(eval_data)
else:
eval_sampler = DistributedSampler(eval_data)
# Run prediction for full data
eval_sampler = SequentialSampler(eval_data)
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.predict_batch_size)
model.eval()