working on squad

This commit is contained in:
thomwolf 2018-11-02 04:07:52 +01:00
parent e61db0d1c0
commit 8e81e5e6ff
2 changed files with 64 additions and 68 deletions

View File

@ -440,7 +440,7 @@ def main():
else:
device = torch.device("cuda", args.local_rank)
n_gpu = 1
print("Initializing the distributed backend: NCCL")
# print("Initializing the distributed backend: NCCL")
print("device", device, "n_gpu", n_gpu)
if not args.do_train and not args.do_eval:

View File

@ -30,6 +30,9 @@ import six
import tensorflow as tf
import argparse
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from modeling_pytorch import BertConfig, BertForQuestionAnswering
from optimization_pytorch import BERTAdam
@ -977,49 +980,13 @@ def main():
tokenizer = tokenization.FullTokenizer(
vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
# tpu_cluster_resolver = None
# if args.use_tpu and args.tpu_name:
# tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
# args.tpu_name, zone=args.tpu_zone, project=args.gcp_project)
# is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
# run_config = tf.contrib.tpu.RunConfig(
# cluster=tpu_cluster_resolver,
# master=args.master,
# model_dir=args.output_dir,
# save_checkpoints_steps=args.save_checkpoints_steps,
# tpu_config=tf.contrib.tpu.TPUConfig(
# iterations_per_loop=args.iterations_per_loop,
# num_shards=args.num_tpu_cores,
# per_host_input_for_training=is_per_host))
train_examples = None
num_train_steps = None
# num_warmup_steps = None
if args.do_train:
train_examples = read_squad_examples(
input_file=args.train_file, is_training=True)
num_train_steps = int(
len(train_examples) / args.train_batch_size * args.num_train_epochs)
# num_warmup_steps = int(num_train_steps * args.warmup_proportion)
# model_fn = model_fn_builder(
# bert_config=bert_config,
# init_checkpoint=args.init_checkpoint,
# learning_rate=args.learning_rate,
# num_train_steps=num_train_steps,
# num_warmup_steps=num_warmup_steps,
# use_tpu=args.use_tpu,
# use_one_hot_embeddings=args.use_tpu)
# If TPU is not available, this will fall back to normal Estimator on CPU
# or GPU.
# estimator = tf.contrib.tpu.TPUEstimator(
# use_tpu=args.use_tpu,
# model_fn=model_fn,
# config=run_config,
# train_batch_size=args.train_batch_size,
# predict_batch_size=args.predict_batch_size)
model = BertForQuestionAnswering(bert_config)
if args.init_checkpoint is not None:
@ -1041,17 +1008,36 @@ def main():
doc_stride=args.doc_stride,
max_query_length=args.max_query_length,
is_training=True)
tf.logging.info("***** Running training *****")
tf.logging.info(" Num orig examples = %d", len(train_examples))
tf.logging.info(" Num split examples = %d", len(train_features))
tf.logging.info(" Batch size = %d", args.train_batch_size)
tf.logging.info(" Num steps = %d", num_train_steps)
train_input_fn = input_fn_builder(
features=train_features,
seq_length=args.max_seq_length,
is_training=True,
drop_remainder=True)
estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
logger.info("***** Running training *****")
logger.info(" Num orig examples = %d", len(train_examples))
logger.info(" Num split examples = %d", len(train_features))
logger.info(" Batch size = %d", args.train_batch_size)
logger.info(" Num steps = %d", num_train_steps)
all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)
train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
if args.local_rank == -1:
train_sampler = RandomSampler(train_data)
else:
train_sampler = DistributedSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
model.train()
for epoch in args.num_train_epochs:
for input_ids, input_mask, segment_ids, label_ids in train_dataloader:
input_ids = input_ids.to(device)
input_mask = input_mask.float().to(device)
segment_ids = segment_ids.to(device)
label_ids = label_ids.to(device)
loss, _ = model(input_ids, segment_ids, input_mask, label_ids)
loss.backward()
optimizer.step()
global_step += 1
if args.do_predict:
eval_examples = read_squad_examples(
@ -1064,29 +1050,39 @@ def main():
max_query_length=args.max_query_length,
is_training=False)
tf.logging.info("***** Running predictions *****")
tf.logging.info(" Num orig examples = %d", len(eval_examples))
tf.logging.info(" Num split examples = %d", len(eval_features))
tf.logging.info(" Batch size = %d", args.predict_batch_size)
logger.info("***** Running predictions *****")
logger.info(" Num orig examples = %d", len(eval_examples))
logger.info(" Num split examples = %d", len(eval_features))
logger.info(" Batch size = %d", args.predict_batch_size)
all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids, all_example_index)
if args.local_rank == -1:
eval_sampler = SequentialSampler(eval_data)
else:
eval_sampler = DistributedSampler(eval_data)
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)
model.eval()
all_results = []
predict_input_fn = input_fn_builder(
features=eval_features,
seq_length=args.max_seq_length,
is_training=False,
drop_remainder=False)
# If running eval on the TPU, you will need to specify the number of
# steps.
all_results = []
for result in estimator.predict(
predict_input_fn, yield_single_examples=True):
for input_ids, input_mask, segment_ids, label_ids, example_index in eval_dataloader:
if len(all_results) % 1000 == 0:
tf.logging.info("Processing example: %d" % (len(all_results)))
unique_id = int(result["unique_ids"])
start_logits = [float(x) for x in result["start_logits"].flat]
end_logits = [float(x) for x in result["end_logits"].flat]
logger.info("Processing example: %d" % (len(all_results)))
input_ids = input_ids.to(device)
input_mask = input_mask.float().to(device)
segment_ids = segment_ids.to(device)
start_logits, end_logits = model(input_ids, segment_ids, input_mask)
unique_id = [int(eval_features[e.item()].unique_id) for e in example_index]
start_logits = [x.item() for x in start_logits]
end_logits = [x.item() for x in end_logits]
all_results.append(
RawResult(
unique_id=unique_id,