Harder check for IndexErrors in QA scripts (#15438)

* Harder check for IndexErrors in QA scripts

* Make test stronger
This commit is contained in:
Sylvain Gugger 2022-02-01 15:49:13 -05:00 committed by GitHub
parent 8e5d4e4906
commit d0b5ed110a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 9 additions and 0 deletions

View File

@@ -137,7 +137,9 @@ def postprocess_qa_predictions(
    start_index >= len(offset_mapping)
    or end_index >= len(offset_mapping)
    or offset_mapping[start_index] is None
    or len(offset_mapping[start_index]) < 2
    or offset_mapping[end_index] is None
    or len(offset_mapping[end_index]) < 2
):
    continue
# Don't consider answers with a length that is either < 0 or > max_answer_length.
@@ -147,6 +149,7 @@ def postprocess_qa_predictions(
# provided).
if token_is_max_context is not None and not token_is_max_context.get(str(start_index), False):
    continue
prelim_predictions.append(
    {
        "offsets": (offset_mapping[start_index][0], offset_mapping[end_index][1]),

View File

@@ -137,7 +137,9 @@ def postprocess_qa_predictions(
    start_index >= len(offset_mapping)
    or end_index >= len(offset_mapping)
    or offset_mapping[start_index] is None
    or len(offset_mapping[start_index]) < 2
    or offset_mapping[end_index] is None
    or len(offset_mapping[end_index]) < 2
):
    continue
# Don't consider answers with a length that is either < 0 or > max_answer_length.
@@ -147,6 +149,7 @@ def postprocess_qa_predictions(
# provided).
if token_is_max_context is not None and not token_is_max_context.get(str(start_index), False):
    continue
prelim_predictions.append(
    {
        "offsets": (offset_mapping[start_index][0], offset_mapping[end_index][1]),

View File

@@ -137,7 +137,9 @@ def postprocess_qa_predictions(
    start_index >= len(offset_mapping)
    or end_index >= len(offset_mapping)
    or offset_mapping[start_index] is None
    or len(offset_mapping[start_index]) < 2
    or offset_mapping[end_index] is None
    or len(offset_mapping[end_index]) < 2
):
    continue
# Don't consider answers with a length that is either < 0 or > max_answer_length.
@@ -147,6 +149,7 @@ def postprocess_qa_predictions(
# provided).
if token_is_max_context is not None and not token_is_max_context.get(str(start_index), False):
    continue
prelim_predictions.append(
    {
        "offsets": (offset_mapping[start_index][0], offset_mapping[end_index][1]),