Mirror of https://github.com/huggingface/transformers.git
Misc. fixes for Pytorch QA examples: (#16958)
1. Fixes evaluation errors that pop up when training/evaluating on SQuAD v2: one newly encountered, and one previously reported in #15401 ("Running SQuAD 1.0 sample command raises IndexError") that was not completely fixed.
2. Removes boolean arguments that don't use store_true. With type=bool, ANY non-empty string is converted to True, which is clearly not the desired behavior and creates a lot of confusion.
3. All no-trainer test scripts now save metric values the same way, with the eval_ prefix, consistent with the trainer-based versions.
4. Adds the forgotten model.eval() call in the no-trainer versions. This improved some results, but not all of them (see the discussion at the end).

Please see the F1 scores and the discussion below.
commit c82e017aa9 (parent 49d5bcb0f3)
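To illustrate item 2 of the commit message, a minimal standalone sketch (not taken from the example scripts) of the type=bool pitfall: argparse passes the raw command-line string to bool(), so any non-empty value, including "False", parses as True, whereas action="store_true" behaves as intended.

import argparse

# Old style (removed by this commit): bool("False") is True, so the flag can never be turned off.
broken = argparse.ArgumentParser()
broken.add_argument("--version_2_with_negative", type=bool, default=False)
print(broken.parse_args(["--version_2_with_negative=False"]).version_2_with_negative)  # True

# New style: the flag is False unless it is passed, with no value attached.
fixed = argparse.ArgumentParser()
fixed.add_argument("--version_2_with_negative", action="store_true")
print(fixed.parse_args([]).version_2_with_negative)                                # False
print(fixed.parse_args(["--version_2_with_negative"]).version_2_with_negative)     # True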
@@ -158,7 +158,7 @@ def postprocess_qa_predictions(
                         "end_logit": end_logits[end_index],
                     }
                 )
-        if version_2_with_negative:
+        if version_2_with_negative and min_null_prediction is not None:
             # Add the minimum null prediction
             prelim_predictions.append(min_null_prediction)
             null_score = min_null_prediction["score"]
@@ -167,7 +167,11 @@ def postprocess_qa_predictions(
         predictions = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[:n_best_size]

         # Add back the minimum null prediction if it was removed because of its low score.
-        if version_2_with_negative and not any(p["offsets"] == (0, 0) for p in predictions):
+        if (
+            version_2_with_negative
+            and min_null_prediction is not None
+            and not any(p["offsets"] == (0, 0) for p in predictions)
+        ):
             predictions.append(min_null_prediction)

         # Use the offsets to gather the answer text in the original context.
@@ -350,9 +354,12 @@ def postprocess_qa_predictions_with_beam_search(
                         start_index >= len(offset_mapping)
                         or end_index >= len(offset_mapping)
                         or offset_mapping[start_index] is None
+                        or len(offset_mapping[start_index]) < 2
                         or offset_mapping[end_index] is None
+                        or len(offset_mapping[end_index]) < 2
                     ):
                         continue

                     # Don't consider answers with a length negative or > max_answer_length.
                     if end_index < start_index or end_index - start_index + 1 > max_answer_length:
                         continue
@@ -381,7 +388,9 @@ def postprocess_qa_predictions_with_beam_search(
         # In the very rare edge case we have not a single non-null prediction, we create a fake prediction to avoid
         # failure.
         if len(predictions) == 0:
-            predictions.insert(0, {"text": "", "start_logit": -1e-6, "end_logit": -1e-6, "score": -2e-6})
+            # Without predictions min_null_score is going to be None and None will cause an exception later
+            min_null_score = -2e-6
+            predictions.insert(0, {"text": "", "start_logit": -1e-6, "end_logit": -1e-6, "score": min_null_score})

         # Compute the softmax of all scores (we do it with numpy to stay independent from torch/tf in this file, using
         # the LogSumExp trick).

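The comment added in the last hunk refers to a later comparison against min_null_score. A small standalone illustration (plain Python, not code from the scripts) of the exception the sentinel value guards against:

best_non_null_score = 1.5                 # made-up score for illustration

min_null_score = None                     # what the old code could leave behind
try:
    prediction = "" if min_null_score > best_non_null_score else "some span"
except TypeError as err:
    # Python 3 refuses to order None against a float:
    # '>' not supported between instances of 'NoneType' and 'float'
    print(err)

min_null_score = -2e-6                    # the sentinel assigned by the fix
prediction = "" if min_null_score > best_non_null_score else "some span"
print(prediction)                         # "some span"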
@@ -19,6 +19,7 @@ Fine-tuning XLNet for question answering with beam search using 🤗 Accelerate.
 # You can also adapt this script on your own question answering task. Pointers for this are left as comments.

 import argparse
+import json
 import logging
 import math
 import os
@@ -60,6 +61,29 @@ require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/ques
 logger = logging.getLogger(__name__)


+def save_prefixed_metrics(results, output_dir, file_name: str = "all_results.json", metric_key_prefix: str = "eval"):
+    """
+    Save results while prefixing metric names.
+
+    Args:
+        results: (:obj:`dict`):
+            A dictionary of results.
+        output_dir: (:obj:`str`):
+            An output directory.
+        file_name: (:obj:`str`, `optional`, defaults to :obj:`all_results.json`):
+            An output file name.
+        metric_key_prefix: (:obj:`str`, `optional`, defaults to :obj:`eval`):
+            A metric name prefix.
+    """
+    # Prefix all keys with metric_key_prefix + '_'
+    for key in list(results.keys()):
+        if not key.startswith(f"{metric_key_prefix}_"):
+            results[f"{metric_key_prefix}_{key}"] = results.pop(key)
+
+    with open(os.path.join(output_dir, file_name), "w") as f:
+        json.dump(results, f, indent=4)
+
+
 def parse_args():
     parser = argparse.ArgumentParser(description="Finetune a transformers model on a Question Answering task")
     parser.add_argument(
@@ -171,8 +195,7 @@ def parse_args():
     )
     parser.add_argument(
         "--version_2_with_negative",
-        type=bool,
-        default=False,
+        action="store_true",
         help="If true, some of the examples do not have an answer.",
     )
     parser.add_argument(
@@ -807,6 +830,9 @@ def main():
     all_end_top_log_probs = []
     all_end_top_index = []
     all_cls_logits = []
+
+    model.eval()
+
     for step, batch in enumerate(eval_dataloader):
         with torch.no_grad():
             outputs = model(**batch)
@@ -864,6 +890,9 @@ def main():
         all_end_top_log_probs = []
         all_end_top_index = []
         all_cls_logits = []
+
+        model.eval()
+
         for step, batch in enumerate(predict_dataloader):
             with torch.no_grad():
                 outputs = model(**batch)
@@ -938,6 +967,9 @@ def main():
             if args.push_to_hub:
                 repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)

+    logger.info(json.dumps(eval_metric, indent=4))
+    save_prefixed_metrics(eval_metric, args.output_dir)
+

 if __name__ == "__main__":
     main()

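A short usage sketch for the save_prefixed_metrics helper introduced above. The metric values and output directory are made-up; the helper itself is the one defined in the hunk.

results = {"f1": 83.2, "exact": 75.1, "eval_samples": 10570}   # hypothetical numbers
save_prefixed_metrics(results, output_dir=".")
# "./all_results.json" now holds eval_f1, eval_exact and the already-prefixed eval_samples,
# i.e. the same key format the trainer-based scripts write, so get_results() in the tests
# can read both variants the same way.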
@@ -66,6 +66,29 @@ MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
 MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)


+def save_prefixed_metrics(results, output_dir, file_name: str = "all_results.json", metric_key_prefix: str = "eval"):
+    """
+    Save results while prefixing metric names.
+
+    Args:
+        results: (:obj:`dict`):
+            A dictionary of results.
+        output_dir: (:obj:`str`):
+            An output directory.
+        file_name: (:obj:`str`, `optional`, defaults to :obj:`all_results.json`):
+            An output file name.
+        metric_key_prefix: (:obj:`str`, `optional`, defaults to :obj:`eval`):
+            A metric name prefix.
+    """
+    # Prefix all keys with metric_key_prefix + '_'
+    for key in list(results.keys()):
+        if not key.startswith(f"{metric_key_prefix}_"):
+            results[f"{metric_key_prefix}_{key}"] = results.pop(key)
+
+    with open(os.path.join(output_dir, file_name), "w") as f:
+        json.dump(results, f, indent=4)
+
+
 def parse_args():
     parser = argparse.ArgumentParser(description="Finetune a transformers model on a Question Answering task")
     parser.add_argument(
@@ -194,8 +217,7 @@ def parse_args():
     )
     parser.add_argument(
         "--version_2_with_negative",
-        type=bool,
-        default=False,
+        action="store_true",
         help="If true, some of the examples do not have an answer.",
     )
     parser.add_argument(
@@ -824,6 +846,9 @@ def main():

     all_start_logits = []
     all_end_logits = []
+
+    model.eval()
+
     for step, batch in enumerate(eval_dataloader):
         with torch.no_grad():
             outputs = model(**batch)
@@ -860,6 +885,9 @@ def main():

         all_start_logits = []
         all_end_logits = []
+
+        model.eval()
+
         for step, batch in enumerate(predict_dataloader):
             with torch.no_grad():
                 outputs = model(**batch)
@@ -907,8 +935,9 @@ def main():
             tokenizer.save_pretrained(args.output_dir)
             if args.push_to_hub:
                 repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
-        with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
-            json.dump({"eval_f1": eval_metric["f1"], "eval_exact": eval_metric["exact"]}, f)
+
+    logger.info(json.dumps(eval_metric, indent=4))
+    save_prefixed_metrics(eval_metric, args.output_dir)


 if __name__ == "__main__":

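Item 4 of the commit message is the model.eval() call added before both evaluation loops above. A minimal generic PyTorch sketch (toy model and data, not the QA model or dataloader from the script) of why torch.no_grad() alone is not enough:

import torch
from torch import nn

# Toy stand-ins purely for illustration -- not the QA model or dataloader from the script.
model = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5), nn.Linear(8, 2))
eval_dataloader = [torch.randn(4, 8) for _ in range(3)]

model.eval()               # switch dropout (and batch-norm style layers) to inference mode
with torch.no_grad():      # no_grad only disables gradient tracking; it does not touch dropout
    for batch in eval_dataloader:
        outputs = model(batch)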
@@ -200,7 +200,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
         testargs = f"""
             run_qa_no_trainer.py
             --model_name_or_path bert-base-uncased
-            --version_2_with_negative=False
+            --version_2_with_negative
             --train_file tests/fixtures/tests_samples/SQUAD/sample.json
             --validation_file tests/fixtures/tests_samples/SQUAD/sample.json
             --output_dir {tmp_dir}
@@ -216,6 +216,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
         with patch.object(sys, "argv", testargs):
             run_squad_no_trainer.main()
             result = get_results(tmp_dir)
+            # Because we use --version_2_with_negative the testing script uses SQuAD v2 metrics.
             self.assertGreaterEqual(result["eval_f1"], 30)
             self.assertGreaterEqual(result["eval_exact"], 30)
             self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))