diff --git a/examples/flax/question-answering/utils_qa.py b/examples/flax/question-answering/utils_qa.py index 82b935f86f3..fd0bc16f7e4 100644 --- a/examples/flax/question-answering/utils_qa.py +++ b/examples/flax/question-answering/utils_qa.py @@ -137,7 +137,9 @@ def postprocess_qa_predictions( start_index >= len(offset_mapping) or end_index >= len(offset_mapping) or offset_mapping[start_index] is None + or len(offset_mapping[start_index]) < 2 or offset_mapping[end_index] is None + or len(offset_mapping[end_index]) < 2 ): continue # Don't consider answers with a length that is either < 0 or > max_answer_length. @@ -147,6 +149,7 @@ def postprocess_qa_predictions( # provided). if token_is_max_context is not None and not token_is_max_context.get(str(start_index), False): continue + prelim_predictions.append( { "offsets": (offset_mapping[start_index][0], offset_mapping[end_index][1]), diff --git a/examples/pytorch/question-answering/utils_qa.py b/examples/pytorch/question-answering/utils_qa.py index 82b935f86f3..fd0bc16f7e4 100644 --- a/examples/pytorch/question-answering/utils_qa.py +++ b/examples/pytorch/question-answering/utils_qa.py @@ -137,7 +137,9 @@ def postprocess_qa_predictions( start_index >= len(offset_mapping) or end_index >= len(offset_mapping) or offset_mapping[start_index] is None + or len(offset_mapping[start_index]) < 2 or offset_mapping[end_index] is None + or len(offset_mapping[end_index]) < 2 ): continue # Don't consider answers with a length that is either < 0 or > max_answer_length. @@ -147,6 +149,7 @@ def postprocess_qa_predictions( # provided). if token_is_max_context is not None and not token_is_max_context.get(str(start_index), False): continue + prelim_predictions.append( { "offsets": (offset_mapping[start_index][0], offset_mapping[end_index][1]), diff --git a/examples/tensorflow/question-answering/utils_qa.py b/examples/tensorflow/question-answering/utils_qa.py index 82b935f86f3..fd0bc16f7e4 100644 --- a/examples/tensorflow/question-answering/utils_qa.py +++ b/examples/tensorflow/question-answering/utils_qa.py @@ -137,7 +137,9 @@ def postprocess_qa_predictions( start_index >= len(offset_mapping) or end_index >= len(offset_mapping) or offset_mapping[start_index] is None + or len(offset_mapping[start_index]) < 2 or offset_mapping[end_index] is None + or len(offset_mapping[end_index]) < 2 ): continue # Don't consider answers with a length that is either < 0 or > max_answer_length. @@ -147,6 +149,7 @@ def postprocess_qa_predictions( # provided). if token_is_max_context is not None and not token_is_max_context.get(str(start_index), False): continue + prelim_predictions.append( { "offsets": (offset_mapping[start_index][0], offset_mapping[end_index][1]),