Mirror of https://github.com/huggingface/transformers.git
Cleanup
Improve global visibility on the run_squad script, remove unused files, and fix XLNet-related issues.
parent 9ecd83dace
commit e9217da5ff
@@ -27,8 +27,7 @@ import glob
import timeit
import numpy as np
import torch
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
                              TensorDataset)
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, TensorDataset)
from torch.utils.data.distributed import DistributedSampler

try:
@@ -48,14 +47,6 @@ from transformers import (WEIGHTS_NAME, BertConfig,

from transformers import AdamW, get_linear_schedule_with_warmup, squad_convert_examples_to_features

from utils_squad import (convert_examples_to_features as old_convert, read_squad_examples as old_read, RawResult, write_predictions,
                         RawResultExtended, write_predictions_extended)

# The following import is the official SQuAD evaluation script (2.0).
# You can remove it from the dependencies if you are using this script outside of the library
# We've added it here for automated tests (see examples/test_examples.py file)
from utils_squad_evaluate import EVAL_OPTS, main as evaluate_on_squad

logger = logging.getLogger(__name__)

ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) \
@@ -98,14 +89,16 @@ def train(args, train_dataset, model, tokenizer):
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")

        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
@@ -133,20 +126,26 @@ def train(args, train_dataset, model, tokenizer):
    model.zero_grad()
    train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
    set_seed(args)  # Added here for reproducibility (even between python 2 and 3)

    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {'input_ids': batch[0],
                      'attention_mask': batch[1],
                      'start_positions': batch[3],
                      'end_positions': batch[4]}

            inputs = {
                'input_ids': batch[0],
                'attention_mask': batch[1],
                'start_positions': batch[3],
                'end_positions': batch[4]
            }

            if args.model_type != 'distilbert':
                inputs['token_type_ids'] = None if args.model_type == 'xlm' else batch[2]

            if args.model_type in ['xlnet', 'xlm']:
                inputs.update({'cls_index': batch[5],
                               'p_mask': batch[6]})
                inputs.update({'cls_index': batch[5], 'p_mask': batch[6]})

            outputs = model(**inputs)
            loss = outputs[0]  # model outputs are always tuples in transformers (see doc)
@@ -173,8 +172,8 @@ def train(args, train_dataset, model, tokenizer):
                model.zero_grad()
                global_step += 1

                # Log metrics
                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    if args.local_rank == -1 and args.evaluate_during_training:  # Only evaluate on a single GPU, otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
@@ -183,8 +182,8 @@ def train(args, train_dataset, model, tokenizer):
                    tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
                    logging_loss = tr_loss

                # Save model checkpoint
                if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
@@ -213,6 +212,7 @@ def evaluate(args, model, tokenizer, prefix=""):
        os.makedirs(args.output_dir)

    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)

    # Note that DistributedSampler samples randomly
    eval_sampler = SequentialSampler(dataset) if args.local_rank == -1 else DistributedSampler(dataset)
    eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
@@ -225,11 +225,14 @@ def evaluate(args, model, tokenizer, prefix=""):
    logger.info("***** Running evaluation {} *****".format(prefix))
    logger.info(" Num examples = %d", len(dataset))
    logger.info(" Batch size = %d", args.eval_batch_size)

    all_results = []
    start_time = timeit.default_timer()

    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        model.eval()
        batch = tuple(t.to(args.device) for t in batch)

        with torch.no_grad():
            inputs = {
                'input_ids': batch[0],
@@ -238,10 +241,13 @@ def evaluate(args, model, tokenizer, prefix=""):

            if args.model_type != 'distilbert':
                inputs['token_type_ids'] = None if args.model_type == 'xlm' else batch[2]  # XLM doesn't use segment_ids

            example_indices = batch[3]

            # XLNet and XLM use more arguments for their predictions
            if args.model_type in ['xlnet', 'xlm']:
                inputs.update({'cls_index': batch[4],
                               'p_mask': batch[5]})
                inputs.update({'cls_index': batch[4], 'p_mask': batch[5]})

            outputs = model(**inputs)

        for i, example_index in enumerate(example_indices):
@@ -250,11 +256,13 @@ def evaluate(args, model, tokenizer, prefix=""):

            output = [to_list(output[i]) for output in outputs]

            # Some models (XLNet, XLM) use 5 arguments for their predictions, while the other "simpler"
            # models only use two.
            if len(output) >= 5:
                start_logits = output[0]
                start_top_index = output[1]
                end_logits = output[2]
                end_top_index = output[3],
                end_top_index = output[3]
                cls_logits = output[4]

                result = SquadResult(
@@ -278,16 +286,17 @@ def evaluate(args, model, tokenizer, prefix=""):
    # Compute predictions
    output_prediction_file = os.path.join(args.output_dir, "predictions_{}.json".format(prefix))
    output_nbest_file = os.path.join(args.output_dir, "nbest_predictions_{}.json".format(prefix))

    if args.version_2_with_negative:
        output_null_log_odds_file = os.path.join(args.output_dir, "null_odds_{}.json".format(prefix))
    else:
        output_null_log_odds_file = None

    # XLNet and XLM use a more complex post-processing procedure
    if args.model_type in ['xlnet', 'xlm']:
        # XLNet uses a more complex post-processing procedure
        predictions = compute_predictions_log_probs(examples, features, all_results, args.n_best_size,
                            args.max_answer_length, output_prediction_file,
                            output_nbest_file, output_null_log_odds_file, args.predict_file,
                            output_nbest_file, output_null_log_odds_file,
                            model.config.start_n_top, model.config.end_n_top,
                            args.version_2_with_negative, tokenizer, args.verbose_logging)
    else:
@@ -296,6 +305,7 @@ def evaluate(args, model, tokenizer, prefix=""):
                            output_nbest_file, output_null_log_odds_file, args.verbose_logging,
                            args.version_2_with_negative, args.null_score_diff_threshold)

    # Compute the F1 and exact scores.
    results = squad_evaluate(examples, predictions)
    return results

@@ -308,7 +318,10 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
    cached_features_file = os.path.join(input_dir, 'cached_{}_{}_{}'.format(
        'dev' if evaluate else 'train',
        list(filter(None, args.model_name_or_path.split('/'))).pop(),
        str(args.max_seq_length)))
        str(args.max_seq_length))
    )

    # Init features and dataset from cache if it exists
    if os.path.exists(cached_features_file) and not args.overwrite_cache and not output_examples:
        logger.info("Loading features from cached file %s", cached_features_file)
        features_and_dataset = torch.load(cached_features_file)
@@ -341,7 +354,6 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
            return_dataset='pt'
        )

    if args.local_rank in [-1, 0]:
        logger.info("Saving features into cached file %s", cached_features_file)
        torch.save({"features": features, "dataset": dataset}, cached_features_file)
@@ -452,6 +464,11 @@ def main():
    parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
    args = parser.parse_args()

    args.predict_file = os.path.join(args.output_dir, 'predictions_{}_{}.txt'.format(
        list(filter(None, args.model_name_or_path.split('/'))).pop(),
        str(args.max_seq_length))
    )

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
        raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
File diff suppressed because it is too large
@@ -1,330 +0,0 @@
""" Official evaluation script for SQuAD version 2.0.
Modified by XLNet authors to update `find_best_threshold` scripts for SQuAD V2.0

In addition to basic functionality, we also compute additional statistics and
plot precision-recall curves if an additional na_prob.json file is provided.
This file is expected to map question IDs to the model's predicted probability
that a question is unanswerable.
"""
import argparse
import collections
import json
import numpy as np
import os
import re
import string
import sys

class EVAL_OPTS():
    def __init__(self, data_file, pred_file, out_file="",
                 na_prob_file="na_prob.json", na_prob_thresh=1.0,
                 out_image_dir=None, verbose=False):
        self.data_file = data_file
        self.pred_file = pred_file
        self.out_file = out_file
        self.na_prob_file = na_prob_file
        self.na_prob_thresh = na_prob_thresh
        self.out_image_dir = out_image_dir
        self.verbose = verbose
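A brief usage sketch of this options object, following the pattern implied by the import removed from run_squad.py above (from utils_squad_evaluate import EVAL_OPTS, main as evaluate_on_squad); the file paths are placeholders:

from utils_squad_evaluate import EVAL_OPTS, main as evaluate_on_squad

# Placeholder paths; main() returns an OrderedDict of EM/F1 metrics.
opts = EVAL_OPTS(data_file="dev-v2.0.json",
                 pred_file="predictions_.json",
                 na_prob_file="null_odds_.json")
results = evaluate_on_squad(opts)
print(results["exact"], results["f1"])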
OPTS = None

def parse_args():
    parser = argparse.ArgumentParser('Official evaluation script for SQuAD version 2.0.')
    parser.add_argument('data_file', metavar='data.json', help='Input data JSON file.')
    parser.add_argument('pred_file', metavar='pred.json', help='Model predictions.')
    parser.add_argument('--out-file', '-o', metavar='eval.json',
                        help='Write accuracy metrics to file (default is stdout).')
    parser.add_argument('--na-prob-file', '-n', metavar='na_prob.json',
                        help='Model estimates of probability of no answer.')
    parser.add_argument('--na-prob-thresh', '-t', type=float, default=1.0,
                        help='Predict "" if no-answer probability exceeds this (default = 1.0).')
    parser.add_argument('--out-image-dir', '-p', metavar='out_images', default=None,
                        help='Save precision-recall curves to directory.')
    parser.add_argument('--verbose', '-v', action='store_true')
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)
    return parser.parse_args()

def make_qid_to_has_ans(dataset):
    qid_to_has_ans = {}
    for article in dataset:
        for p in article['paragraphs']:
            for qa in p['qas']:
                qid_to_has_ans[qa['id']] = bool(qa['answers'])
    return qid_to_has_ans
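A toy illustration (invented ids and text) of the nested SQuAD "data" structure this helper walks, and the mapping it produces:

dataset = [{
    'paragraphs': [{
        'qas': [
            {'id': 'q1', 'answers': [{'text': 'Denver Broncos'}]},  # answerable
            {'id': 'q2', 'answers': []},                            # unanswerable (SQuAD v2.0)
        ]
    }]
}]
assert make_qid_to_has_ans(dataset) == {'q1': True, 'q2': False}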
def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""
    def remove_articles(text):
        regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
        return re.sub(regex, ' ', text)
    def white_space_fix(text):
        return ' '.join(text.split())
    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)
    def lower(text):
        return text.lower()
    return white_space_fix(remove_articles(remove_punc(lower(s))))

def get_tokens(s):
    if not s: return []
    return normalize_answer(s).split()

def compute_exact(a_gold, a_pred):
    return int(normalize_answer(a_gold) == normalize_answer(a_pred))

def compute_f1(a_gold, a_pred):
    gold_toks = get_tokens(a_gold)
    pred_toks = get_tokens(a_pred)
    common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
    num_same = sum(common.values())
    if len(gold_toks) == 0 or len(pred_toks) == 0:
        # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
        return int(gold_toks == pred_toks)
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(pred_toks)
    recall = 1.0 * num_same / len(gold_toks)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1
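A quick worked example of the token-overlap scoring above (strings invented):

# normalize_answer drops articles and case, so these count as an exact match:
print(compute_exact("the Broncos", "Broncos"))   # 1
# One shared token out of two gold tokens and one predicted token:
# precision = 1.0, recall = 0.5, F1 = 2 * 1.0 * 0.5 / 1.5 = 0.666...
print(compute_f1("Denver Broncos", "Broncos"))   # 0.666...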
def get_raw_scores(dataset, preds):
    exact_scores = {}
    f1_scores = {}
    for article in dataset:
        for p in article['paragraphs']:
            for qa in p['qas']:
                qid = qa['id']
                gold_answers = [a['text'] for a in qa['answers']
                                if normalize_answer(a['text'])]
                if not gold_answers:
                    # For unanswerable questions, only correct answer is empty string
                    gold_answers = ['']
                if qid not in preds:
                    print('Missing prediction for %s' % qid)
                    continue
                a_pred = preds[qid]
                # Take max over all gold answers
                exact_scores[qid] = max(compute_exact(a, a_pred) for a in gold_answers)
                f1_scores[qid] = max(compute_f1(a, a_pred) for a in gold_answers)
    return exact_scores, f1_scores

def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh):
    new_scores = {}
    for qid, s in scores.items():
        pred_na = na_probs[qid] > na_prob_thresh
        if pred_na:
            new_scores[qid] = float(not qid_to_has_ans[qid])
        else:
            new_scores[qid] = s
    return new_scores

def make_eval_dict(exact_scores, f1_scores, qid_list=None):
    if not qid_list:
        total = len(exact_scores)
        return collections.OrderedDict([
            ('exact', 100.0 * sum(exact_scores.values()) / total),
            ('f1', 100.0 * sum(f1_scores.values()) / total),
            ('total', total),
        ])
    else:
        total = len(qid_list)
        return collections.OrderedDict([
            ('exact', 100.0 * sum(exact_scores[k] for k in qid_list) / total),
            ('f1', 100.0 * sum(f1_scores[k] for k in qid_list) / total),
            ('total', total),
        ])

def merge_eval(main_eval, new_eval, prefix):
    for k in new_eval:
        main_eval['%s_%s' % (prefix, k)] = new_eval[k]

def plot_pr_curve(precisions, recalls, out_image, title):
    plt.step(recalls, precisions, color='b', alpha=0.2, where='post')
    plt.fill_between(recalls, precisions, step='post', alpha=0.2, color='b')
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.xlim([0.0, 1.05])
    plt.ylim([0.0, 1.05])
    plt.title(title)
    plt.savefig(out_image)
    plt.clf()

def make_precision_recall_eval(scores, na_probs, num_true_pos, qid_to_has_ans,
                               out_image=None, title=None):
    qid_list = sorted(na_probs, key=lambda k: na_probs[k])
    true_pos = 0.0
    cur_p = 1.0
    cur_r = 0.0
    precisions = [1.0]
    recalls = [0.0]
    avg_prec = 0.0
    for i, qid in enumerate(qid_list):
        if qid_to_has_ans[qid]:
            true_pos += scores[qid]
        cur_p = true_pos / float(i+1)
        cur_r = true_pos / float(num_true_pos)
        if i == len(qid_list) - 1 or na_probs[qid] != na_probs[qid_list[i+1]]:
            # i.e., if we can put a threshold after this point
            avg_prec += cur_p * (cur_r - recalls[-1])
            precisions.append(cur_p)
            recalls.append(cur_r)
    if out_image:
        plot_pr_curve(precisions, recalls, out_image, title)
    return {'ap': 100.0 * avg_prec}
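For readability, the 'ap' value returned above is a step-wise area under the precision-recall curve; a minimal restatement of the accumulation, assuming the precisions/recalls lists built in the loop:

def average_precision(precisions, recalls):
    # Sum of precision * (recall increment) over successive threshold points,
    # matching avg_prec in make_precision_recall_eval (before the 100x scaling).
    return sum(p * (r1 - r0)
               for p, r0, r1 in zip(precisions[1:], recalls[:-1], recalls[1:]))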
def run_precision_recall_analysis(main_eval, exact_raw, f1_raw, na_probs,
                                  qid_to_has_ans, out_image_dir):
    if out_image_dir and not os.path.exists(out_image_dir):
        os.makedirs(out_image_dir)
    num_true_pos = sum(1 for v in qid_to_has_ans.values() if v)
    if num_true_pos == 0:
        return
    pr_exact = make_precision_recall_eval(
        exact_raw, na_probs, num_true_pos, qid_to_has_ans,
        out_image=os.path.join(out_image_dir, 'pr_exact.png'),
        title='Precision-Recall curve for Exact Match score')
    pr_f1 = make_precision_recall_eval(
        f1_raw, na_probs, num_true_pos, qid_to_has_ans,
        out_image=os.path.join(out_image_dir, 'pr_f1.png'),
        title='Precision-Recall curve for F1 score')
    oracle_scores = {k: float(v) for k, v in qid_to_has_ans.items()}
    pr_oracle = make_precision_recall_eval(
        oracle_scores, na_probs, num_true_pos, qid_to_has_ans,
        out_image=os.path.join(out_image_dir, 'pr_oracle.png'),
        title='Oracle Precision-Recall curve (binary task of HasAns vs. NoAns)')
    merge_eval(main_eval, pr_exact, 'pr_exact')
    merge_eval(main_eval, pr_f1, 'pr_f1')
    merge_eval(main_eval, pr_oracle, 'pr_oracle')

def histogram_na_prob(na_probs, qid_list, image_dir, name):
    if not qid_list:
        return
    x = [na_probs[k] for k in qid_list]
    weights = np.ones_like(x) / float(len(x))
    plt.hist(x, weights=weights, bins=20, range=(0.0, 1.0))
    plt.xlabel('Model probability of no-answer')
    plt.ylabel('Proportion of dataset')
    plt.title('Histogram of no-answer probability: %s' % name)
    plt.savefig(os.path.join(image_dir, 'na_prob_hist_%s.png' % name))
    plt.clf()

def find_best_thresh(preds, scores, na_probs, qid_to_has_ans):
    num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
    cur_score = num_no_ans
    best_score = cur_score
    best_thresh = 0.0
    qid_list = sorted(na_probs, key=lambda k: na_probs[k])
    for i, qid in enumerate(qid_list):
        if qid not in scores: continue
        if qid_to_has_ans[qid]:
            diff = scores[qid]
        else:
            if preds[qid]:
                diff = -1
            else:
                diff = 0
        cur_score += diff
        if cur_score > best_score:
            best_score = cur_score
            best_thresh = na_probs[qid]
    return 100.0 * best_score / len(scores), best_thresh
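A toy walk-through of the sweep above (all values invented): the starting score assumes "" is predicted for every question, and each answerable question whose prediction scores well raises the achievable total as the threshold passes its no-answer probability.

preds          = {'q1': 'Broncos', 'q2': ''}
exact_scores   = {'q1': 1,         'q2': 1}
na_probs       = {'q1': 0.1,       'q2': 0.8}
qid_to_has_ans = {'q1': True,      'q2': False}
best, thresh = find_best_thresh(preds, exact_scores, na_probs, qid_to_has_ans)
print(best, thresh)  # 100.0 0.1 -- answering q1 and abstaining on q2 is optimal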
def find_best_thresh_v2(preds, scores, na_probs, qid_to_has_ans):
    num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
    cur_score = num_no_ans
    best_score = cur_score
    best_thresh = 0.0
    qid_list = sorted(na_probs, key=lambda k: na_probs[k])
    for i, qid in enumerate(qid_list):
        if qid not in scores: continue
        if qid_to_has_ans[qid]:
            diff = scores[qid]
        else:
            if preds[qid]:
                diff = -1
            else:
                diff = 0
        cur_score += diff
        if cur_score > best_score:
            best_score = cur_score
            best_thresh = na_probs[qid]

    has_ans_score, has_ans_cnt = 0, 0
    for qid in qid_list:
        if not qid_to_has_ans[qid]: continue
        has_ans_cnt += 1

        if qid not in scores: continue
        has_ans_score += scores[qid]

    return 100.0 * best_score / len(scores), best_thresh, 1.0 * has_ans_score / has_ans_cnt

def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
    best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans)
    best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans)
    main_eval['best_exact'] = best_exact
    main_eval['best_exact_thresh'] = exact_thresh
    main_eval['best_f1'] = best_f1
    main_eval['best_f1_thresh'] = f1_thresh

def find_all_best_thresh_v2(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
    best_exact, exact_thresh, has_ans_exact = find_best_thresh_v2(preds, exact_raw, na_probs, qid_to_has_ans)
    best_f1, f1_thresh, has_ans_f1 = find_best_thresh_v2(preds, f1_raw, na_probs, qid_to_has_ans)
    main_eval['best_exact'] = best_exact
    main_eval['best_exact_thresh'] = exact_thresh
    main_eval['best_f1'] = best_f1
    main_eval['best_f1_thresh'] = f1_thresh
    main_eval['has_ans_exact'] = has_ans_exact
    main_eval['has_ans_f1'] = has_ans_f1

def main(OPTS):
    with open(OPTS.data_file) as f:
        dataset_json = json.load(f)
        dataset = dataset_json['data']
    with open(OPTS.pred_file) as f:
        preds = json.load(f)
    if OPTS.na_prob_file:
        with open(OPTS.na_prob_file) as f:
            na_probs = json.load(f)
    else:
        na_probs = {k: 0.0 for k in preds}
    qid_to_has_ans = make_qid_to_has_ans(dataset)  # maps qid to True/False
    has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
    no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
    exact_raw, f1_raw = get_raw_scores(dataset, preds)
    exact_thresh = apply_no_ans_threshold(exact_raw, na_probs, qid_to_has_ans,
                                          OPTS.na_prob_thresh)
    f1_thresh = apply_no_ans_threshold(f1_raw, na_probs, qid_to_has_ans,
                                       OPTS.na_prob_thresh)
    out_eval = make_eval_dict(exact_thresh, f1_thresh)
    if has_ans_qids:
        has_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=has_ans_qids)
        merge_eval(out_eval, has_ans_eval, 'HasAns')
    if no_ans_qids:
        no_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids)
        merge_eval(out_eval, no_ans_eval, 'NoAns')
    if OPTS.na_prob_file:
        find_all_best_thresh(out_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans)
    if OPTS.na_prob_file and OPTS.out_image_dir:
        run_precision_recall_analysis(out_eval, exact_raw, f1_raw, na_probs,
                                      qid_to_has_ans, OPTS.out_image_dir)
        histogram_na_prob(na_probs, has_ans_qids, OPTS.out_image_dir, 'hasAns')
        histogram_na_prob(na_probs, no_ans_qids, OPTS.out_image_dir, 'noAns')
    if OPTS.out_file:
        with open(OPTS.out_file, 'w') as f:
            json.dump(out_eval, f)
    else:
        print(json.dumps(out_eval, indent=2))
    return out_eval

if __name__ == '__main__':
    OPTS = parse_args()
    if OPTS.out_image_dir:
        import matplotlib
        matplotlib.use('Agg')
        import matplotlib.pyplot as plt
    main(OPTS)
@@ -578,7 +578,6 @@ def compute_predictions_log_probs(
    output_prediction_file,
    output_nbest_file,
    output_null_log_odds_file,
    orig_data_file,
    start_n_top,
    end_n_top,
    version_2_with_negative,
@@ -756,15 +755,4 @@ def compute_predictions_log_probs(
        with open(output_null_log_odds_file, "w") as writer:
            writer.write(json.dumps(scores_diff_json, indent=4) + "\n")

    with open(orig_data_file, "r", encoding='utf-8') as reader:
        orig_data = json.load(reader)["data"]

    qid_to_has_ans = make_qid_to_has_ans(orig_data)
    has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
    no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
    exact_raw, f1_raw = get_raw_scores(orig_data, all_predictions)
    out_eval = {}

    find_all_best_thresh_v2(out_eval, all_predictions, exact_raw, f1_raw, scores_diff_json, qid_to_has_ans)

    return out_eval
    return all_predictions
@@ -9,7 +9,7 @@ from ...tokenization_bert import BasicTokenizer, whitespace_tokenize
from .utils import DataProcessor, InputExample, InputFeatures
from ...file_utils import is_tf_available, is_torch_available

if is_torch_available:
if is_torch_available():
    import torch
    from torch.utils.data import TensorDataset
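This last hunk fixes a truthiness bug: the un-called is_torch_available is a function object, which is always truthy, so the guard could never skip the torch import. A minimal sketch of the difference, using a hypothetical stub in place of the real helper in file_utils:

def is_torch_available():  # hypothetical stub standing in for transformers.file_utils
    return False

print(bool(is_torch_available))    # True  -- the bug: a bare function object is truthy
print(bool(is_torch_available()))  # False -- the fix: test the value it returns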