Misc. fixes for PyTorch QA examples (#16958)

1. Fixes evaluation errors that pop up when you train/eval on SQuAD v2 (one was newly encountered, and one was previously reported in "Running SQuAD 1.0 sample command raises IndexError" #15401 but not completely fixed).
2. Removes boolean arguments that don't use store_true. Please don't use these: with type=bool, ANY non-empty string is converted to True, which is clearly not the desired behavior (and it creates a LOT of confusion). See the argparse sketch right after this list.
3. All no-trainer scripts now save metric values in the same way (with the proper eval_ prefix), which is consistent with the trainer-based versions.
4. Adds the forgotten model.eval() call in the no-trainer versions. This improved some results, but not all of them; see the F1 scores and the discussion below. A short model.eval() sketch also follows this list.
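
A minimal, hypothetical sketch of point 2 (not code from the repository; the flag names are made up for illustration): argparse applies bool() to the raw command-line string, and any non-empty string, including "False", is truthy, so a type=bool flag can never be switched off from the command line, whereas a store_true flag behaves as expected.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--broken_flag", type=bool, default=False)  # anti-pattern removed by this commit
parser.add_argument("--good_flag", action="store_true")         # pattern the scripts use now

args = parser.parse_args(["--broken_flag", "False"])
print(args.broken_flag)  # True: bool("False") is True because the string is non-empty

args = parser.parse_args([])
print(args.good_flag)    # False: a store_true flag defaults to False unless it is passed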
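
And a short sketch of point 4, with illustrative names rather than the script's own: torch.no_grad() only disables gradient tracking, while model.eval() switches dropout and similar train-only layers into inference mode, so a stable evaluation loop needs both.

import torch

def evaluate(model, eval_dataloader):
    model.eval()  # the call that was missing: puts dropout/batch norm into inference mode
    all_start_logits = []
    for batch in eval_dataloader:
        with torch.no_grad():  # no_grad() alone does not disable dropout
            outputs = model(**batch)
        all_start_logits.append(outputs.start_logits.cpu())
    model.train()  # restore training mode if training continues afterwards
    return all_start_logits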
Leonid Boytsov 2022-04-27 12:51:39 +00:00 committed by GitHub
parent 49d5bcb0f3
commit c82e017aa9
6 changed files with 105 additions and 16 deletions

View File

@@ -158,7 +158,7 @@ def postprocess_qa_predictions(
"end_logit": end_logits[end_index],
}
)
- if version_2_with_negative:
+ if version_2_with_negative and min_null_prediction is not None:
# Add the minimum null prediction
prelim_predictions.append(min_null_prediction)
null_score = min_null_prediction["score"]
@@ -167,7 +167,11 @@ def postprocess_qa_predictions(
predictions = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[:n_best_size]
# Add back the minimum null prediction if it was removed because of its low score.
- if version_2_with_negative and not any(p["offsets"] == (0, 0) for p in predictions):
+ if (
+     version_2_with_negative
+     and min_null_prediction is not None
+     and not any(p["offsets"] == (0, 0) for p in predictions)
+ ):
predictions.append(min_null_prediction)
# Use the offsets to gather the answer text in the original context.
@@ -350,9 +354,12 @@ def postprocess_qa_predictions_with_beam_search(
start_index >= len(offset_mapping)
or end_index >= len(offset_mapping)
or offset_mapping[start_index] is None
or len(offset_mapping[start_index]) < 2
or offset_mapping[end_index] is None
or len(offset_mapping[end_index]) < 2
):
continue
# Don't consider answers with a length negative or > max_answer_length.
if end_index < start_index or end_index - start_index + 1 > max_answer_length:
continue
@@ -381,7 +388,9 @@ def postprocess_qa_predictions_with_beam_search(
# In the very rare edge case we have not a single non-null prediction, we create a fake prediction to avoid
# failure.
if len(predictions) == 0:
- predictions.insert(0, {"text": "", "start_logit": -1e-6, "end_logit": -1e-6, "score": -2e-6})
+ # Without predictions min_null_score is going to be None and None will cause an exception later
+ min_null_score = -2e-6
+ predictions.insert(0, {"text": "", "start_logit": -1e-6, "end_logit": -1e-6, "score": min_null_score})
# Compute the softmax of all scores (we do it with numpy to stay independent from torch/tf in this file, using
# the LogSumExp trick).

View File

@@ -19,6 +19,7 @@ Fine-tuning XLNet for question answering with beam search using 🤗 Accelerate.
# You can also adapt this script on your own question answering task. Pointers for this are left as comments.
import argparse
import json
import logging
import math
import os
@@ -60,6 +61,29 @@ require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/ques
logger = logging.getLogger(__name__)
def save_prefixed_metrics(results, output_dir, file_name: str = "all_results.json", metric_key_prefix: str = "eval"):
"""
Save results while prefixing metric names.
Args:
results: (:obj:`dict`):
A dictionary of results.
output_dir: (:obj:`str`):
An output directory.
file_name: (:obj:`str`, `optional`, defaults to :obj:`all_results.json`):
An output file name.
metric_key_prefix: (:obj:`str`, `optional`, defaults to :obj:`eval`):
A metric name prefix.
"""
# Prefix all keys with metric_key_prefix + '_'
for key in list(results.keys()):
if not key.startswith(f"{metric_key_prefix}_"):
results[f"{metric_key_prefix}_{key}"] = results.pop(key)
with open(os.path.join(output_dir, file_name), "w") as f:
json.dump(results, f, indent=4)
def parse_args():
parser = argparse.ArgumentParser(description="Finetune a transformers model on a Question Answering task")
parser.add_argument(
@@ -171,8 +195,7 @@ def parse_args():
)
parser.add_argument(
"--version_2_with_negative",
- type=bool,
- default=False,
+ action="store_true",
help="If true, some of the examples do not have an answer.",
)
parser.add_argument(
@@ -807,6 +830,9 @@ def main():
all_end_top_log_probs = []
all_end_top_index = []
all_cls_logits = []
model.eval()
for step, batch in enumerate(eval_dataloader):
with torch.no_grad():
outputs = model(**batch)
@@ -864,6 +890,9 @@ def main():
all_end_top_log_probs = []
all_end_top_index = []
all_cls_logits = []
model.eval()
for step, batch in enumerate(predict_dataloader):
with torch.no_grad():
outputs = model(**batch)
@@ -938,6 +967,9 @@ def main():
if args.push_to_hub:
repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
logger.info(json.dumps(eval_metric, indent=4))
save_prefixed_metrics(eval_metric, args.output_dir)
if __name__ == "__main__":
main()

View File

@@ -66,6 +66,29 @@ MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
def save_prefixed_metrics(results, output_dir, file_name: str = "all_results.json", metric_key_prefix: str = "eval"):
"""
Save results while prefixing metric names.
Args:
results: (:obj:`dict`):
A dictionary of results.
output_dir: (:obj:`str`):
An output directory.
file_name: (:obj:`str`, `optional`, defaults to :obj:`all_results.json`):
An output file name.
metric_key_prefix: (:obj:`str`, `optional`, defaults to :obj:`eval`):
A metric name prefix.
"""
# Prefix all keys with metric_key_prefix + '_'
for key in list(results.keys()):
if not key.startswith(f"{metric_key_prefix}_"):
results[f"{metric_key_prefix}_{key}"] = results.pop(key)
with open(os.path.join(output_dir, file_name), "w") as f:
json.dump(results, f, indent=4)
def parse_args():
parser = argparse.ArgumentParser(description="Finetune a transformers model on a Question Answering task")
parser.add_argument(
@@ -194,8 +217,7 @@ def parse_args():
)
parser.add_argument(
"--version_2_with_negative",
- type=bool,
- default=False,
+ action="store_true",
help="If true, some of the examples do not have an answer.",
)
parser.add_argument(
@@ -824,6 +846,9 @@ def main():
all_start_logits = []
all_end_logits = []
model.eval()
for step, batch in enumerate(eval_dataloader):
with torch.no_grad():
outputs = model(**batch)
@@ -860,6 +885,9 @@ def main():
all_start_logits = []
all_end_logits = []
model.eval()
for step, batch in enumerate(predict_dataloader):
with torch.no_grad():
outputs = model(**batch)
@@ -907,8 +935,9 @@ def main():
tokenizer.save_pretrained(args.output_dir)
if args.push_to_hub:
repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
- with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
-     json.dump({"eval_f1": eval_metric["f1"], "eval_exact": eval_metric["exact"]}, f)
+ logger.info(json.dumps(eval_metric, indent=4))
+ save_prefixed_metrics(eval_metric, args.output_dir)
if __name__ == "__main__":

View File

@@ -158,7 +158,7 @@ def postprocess_qa_predictions(
"end_logit": end_logits[end_index],
}
)
- if version_2_with_negative:
+ if version_2_with_negative and min_null_prediction is not None:
# Add the minimum null prediction
prelim_predictions.append(min_null_prediction)
null_score = min_null_prediction["score"]
@@ -167,7 +167,11 @@ def postprocess_qa_predictions(
predictions = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[:n_best_size]
# Add back the minimum null prediction if it was removed because of its low score.
- if version_2_with_negative and not any(p["offsets"] == (0, 0) for p in predictions):
+ if (
+     version_2_with_negative
+     and min_null_prediction is not None
+     and not any(p["offsets"] == (0, 0) for p in predictions)
+ ):
predictions.append(min_null_prediction)
# Use the offsets to gather the answer text in the original context.
@@ -350,9 +354,12 @@ def postprocess_qa_predictions_with_beam_search(
start_index >= len(offset_mapping)
or end_index >= len(offset_mapping)
or offset_mapping[start_index] is None
or len(offset_mapping[start_index]) < 2
or offset_mapping[end_index] is None
or len(offset_mapping[end_index]) < 2
):
continue
# Don't consider answers with a length negative or > max_answer_length.
if end_index < start_index or end_index - start_index + 1 > max_answer_length:
continue
@@ -381,7 +388,9 @@ def postprocess_qa_predictions_with_beam_search(
# In the very rare edge case we have not a single non-null prediction, we create a fake prediction to avoid
# failure.
if len(predictions) == 0:
- predictions.insert(0, {"text": "", "start_logit": -1e-6, "end_logit": -1e-6, "score": -2e-6})
+ # Without predictions min_null_score is going to be None and None will cause an exception later
+ min_null_score = -2e-6
+ predictions.insert(0, {"text": "", "start_logit": -1e-6, "end_logit": -1e-6, "score": min_null_score})
# Compute the softmax of all scores (we do it with numpy to stay independent from torch/tf in this file, using
# the LogSumExp trick).

View File

@@ -200,7 +200,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
testargs = f"""
run_qa_no_trainer.py
--model_name_or_path bert-base-uncased
- --version_2_with_negative=False
+ --version_2_with_negative
--train_file tests/fixtures/tests_samples/SQUAD/sample.json
--validation_file tests/fixtures/tests_samples/SQUAD/sample.json
--output_dir {tmp_dir}
@@ -216,6 +216,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
with patch.object(sys, "argv", testargs):
run_squad_no_trainer.main()
result = get_results(tmp_dir)
# Because we use --version_2_with_negative the testing script uses SQuAD v2 metrics.
self.assertGreaterEqual(result["eval_f1"], 30)
self.assertGreaterEqual(result["eval_exact"], 30)
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))

View File

@@ -158,7 +158,7 @@ def postprocess_qa_predictions(
"end_logit": end_logits[end_index],
}
)
- if version_2_with_negative:
+ if version_2_with_negative and min_null_prediction is not None:
# Add the minimum null prediction
prelim_predictions.append(min_null_prediction)
null_score = min_null_prediction["score"]
@@ -167,7 +167,11 @@ def postprocess_qa_predictions(
predictions = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[:n_best_size]
# Add back the minimum null prediction if it was removed because of its low score.
- if version_2_with_negative and not any(p["offsets"] == (0, 0) for p in predictions):
+ if (
+     version_2_with_negative
+     and min_null_prediction is not None
+     and not any(p["offsets"] == (0, 0) for p in predictions)
+ ):
predictions.append(min_null_prediction)
# Use the offsets to gather the answer text in the original context.
@@ -350,9 +354,12 @@ def postprocess_qa_predictions_with_beam_search(
start_index >= len(offset_mapping)
or end_index >= len(offset_mapping)
or offset_mapping[start_index] is None
or len(offset_mapping[start_index]) < 2
or offset_mapping[end_index] is None
or len(offset_mapping[end_index]) < 2
):
continue
# Don't consider answers with a length negative or > max_answer_length.
if end_index < start_index or end_index - start_index + 1 > max_answer_length:
continue
@@ -381,7 +388,9 @@ def postprocess_qa_predictions_with_beam_search(
# In the very rare edge case we have not a single non-null prediction, we create a fake prediction to avoid
# failure.
if len(predictions) == 0:
- predictions.insert(0, {"text": "", "start_logit": -1e-6, "end_logit": -1e-6, "score": -2e-6})
+ # Without predictions min_null_score is going to be None and None will cause an exception later
+ min_null_score = -2e-6
+ predictions.insert(0, {"text": "", "start_logit": -1e-6, "end_logit": -1e-6, "score": min_null_score})
# Compute the softmax of all scores (we do it with numpy to stay independent from torch/tf in this file, using
# the LogSumExp trick).