From d0b5ed110aa034f21de3a1b4c043b44a7dc59c41 Mon Sep 17 00:00:00 2001
From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Date: Tue, 1 Feb 2022 15:49:13 -0500
Subject: [PATCH] Harder check for IndexErrors in QA scripts (#15438)

* Harder check for IndexErrors in QA scripts

* Make test stronger
---
 examples/flax/question-answering/utils_qa.py       | 3 +++
 examples/pytorch/question-answering/utils_qa.py    | 3 +++
 examples/tensorflow/question-answering/utils_qa.py | 3 +++
 3 files changed, 9 insertions(+)

diff --git a/examples/flax/question-answering/utils_qa.py b/examples/flax/question-answering/utils_qa.py
index 82b935f86f3..fd0bc16f7e4 100644
--- a/examples/flax/question-answering/utils_qa.py
+++ b/examples/flax/question-answering/utils_qa.py
@@ -137,7 +137,9 @@ def postprocess_qa_predictions(
                         start_index >= len(offset_mapping)
                         or end_index >= len(offset_mapping)
                         or offset_mapping[start_index] is None
+                        or len(offset_mapping[start_index]) < 2
                         or offset_mapping[end_index] is None
+                        or len(offset_mapping[end_index]) < 2
                     ):
                         continue
                     # Don't consider answers with a length that is either < 0 or > max_answer_length.
@@ -147,6 +149,7 @@ def postprocess_qa_predictions(
                     # provided).
                     if token_is_max_context is not None and not token_is_max_context.get(str(start_index), False):
                         continue
+
                     prelim_predictions.append(
                         {
                             "offsets": (offset_mapping[start_index][0], offset_mapping[end_index][1]),
diff --git a/examples/pytorch/question-answering/utils_qa.py b/examples/pytorch/question-answering/utils_qa.py
index 82b935f86f3..fd0bc16f7e4 100644
--- a/examples/pytorch/question-answering/utils_qa.py
+++ b/examples/pytorch/question-answering/utils_qa.py
@@ -137,7 +137,9 @@ def postprocess_qa_predictions(
                         start_index >= len(offset_mapping)
                         or end_index >= len(offset_mapping)
                         or offset_mapping[start_index] is None
+                        or len(offset_mapping[start_index]) < 2
                         or offset_mapping[end_index] is None
+                        or len(offset_mapping[end_index]) < 2
                     ):
                         continue
                     # Don't consider answers with a length that is either < 0 or > max_answer_length.
@@ -147,6 +149,7 @@ def postprocess_qa_predictions(
                     # provided).
                     if token_is_max_context is not None and not token_is_max_context.get(str(start_index), False):
                         continue
+
                     prelim_predictions.append(
                         {
                             "offsets": (offset_mapping[start_index][0], offset_mapping[end_index][1]),
diff --git a/examples/tensorflow/question-answering/utils_qa.py b/examples/tensorflow/question-answering/utils_qa.py
index 82b935f86f3..fd0bc16f7e4 100644
--- a/examples/tensorflow/question-answering/utils_qa.py
+++ b/examples/tensorflow/question-answering/utils_qa.py
@@ -137,7 +137,9 @@ def postprocess_qa_predictions(
                         start_index >= len(offset_mapping)
                         or end_index >= len(offset_mapping)
                         or offset_mapping[start_index] is None
+                        or len(offset_mapping[start_index]) < 2
                         or offset_mapping[end_index] is None
+                        or len(offset_mapping[end_index]) < 2
                     ):
                         continue
                     # Don't consider answers with a length that is either < 0 or > max_answer_length.
@@ -147,6 +149,7 @@ def postprocess_qa_predictions(
                     # provided).
                     if token_is_max_context is not None and not token_is_max_context.get(str(start_index), False):
                         continue
+
                     prelim_predictions.append(
                         {
                             "offsets": (offset_mapping[start_index][0], offset_mapping[end_index][1]),
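
Note (not part of the patch): below is a minimal, stand-alone sketch of the guard the patch strengthens. The helper name `span_is_usable` and the sample `offsets` list are hypothetical, and the premise that a tokenizer's offset_mapping can contain None or empty entries (e.g. for special tokens) is an assumption here; in the repository the check lives inline in postprocess_qa_predictions.

    # Hypothetical, self-contained illustration of the strengthened guard.
    # Assumption: offset_mapping entries may be None or empty tuples, which would
    # make offset_mapping[i][0] / offset_mapping[i][1] raise IndexError later on.
    def span_is_usable(offset_mapping, start_index, end_index, max_answer_length=30):
        """Return True only if (start_index, end_index) can safely be turned into character offsets."""
        if (
            start_index >= len(offset_mapping)
            or end_index >= len(offset_mapping)
            or offset_mapping[start_index] is None
            or len(offset_mapping[start_index]) < 2  # check added by the patch
            or offset_mapping[end_index] is None
            or len(offset_mapping[end_index]) < 2  # check added by the patch
        ):
            return False
        # Same length constraint as in the QA post-processing.
        if end_index < start_index or end_index - start_index + 1 > max_answer_length:
            return False
        return True

    offsets = [(0, 4), (), None, (5, 9)]  # degenerate entries on purpose
    print(span_is_usable(offsets, 0, 3))  # True: offsets[0][0] and offsets[3][1] are safe to read
    print(span_is_usable(offsets, 0, 1))  # False: empty tuple would have raised IndexError without the new check
    print(span_is_usable(offsets, 2, 3))  # False: None entry, already rejected before the patch

Without the added len(...) < 2 checks, a span landing on an empty offset entry would only fail later, when "offsets": (offset_mapping[start_index][0], offset_mapping[end_index][1]) is built for prelim_predictions.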