Fix min_null_pred in the run_qa script (#9067)

This commit is contained in:
Sylvain Gugger 2020-12-11 16:26:05 -05:00 committed by GitHub
parent 9cc9f4122e
commit 29e4597950
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -76,9 +76,7 @@ def postprocess_qa_predictions(
assert len(predictions) == 2, "`predictions` should be a tuple with two elements (start_logits, end_logits)."
all_start_logits, all_end_logits = predictions
assert len(predictions[0]) == len(
features
), f"Got {len(predictions[0])} predicitions and {len(features)} features."
assert len(predictions[0]) == len(features), f"Got {len(predictions[0])} predictions and {len(features)} features."
# Build a map example to its corresponding features.
example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
@ -118,7 +116,7 @@ def postprocess_qa_predictions(
# Update minimum null prediction.
feature_null_score = start_logits[0] + end_logits[0]
if min_null_prediction is None or min_null_prediction["score"] < feature_null_score:
if min_null_prediction is None or min_null_prediction["score"] > feature_null_score:
min_null_prediction = {
"offsets": (0, 0),
"score": feature_null_score,