fix: Update run_qa.py to work with deepset/germanquad (#23225)

Call str on id to make sure any ints are converted into the expected format for squad datasets
This commit is contained in:
Sebastian 2023-05-09 15:20:10 +02:00 committed by GitHub
parent 51ae566511
commit 1a8f61110e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -590,12 +590,12 @@ def main():
# Format the result to the format the metric expects.
if data_args.version_2_with_negative:
formatted_predictions = [
{"id": k, "prediction_text": v, "no_answer_probability": 0.0} for k, v in predictions.items()
{"id": str(k), "prediction_text": v, "no_answer_probability": 0.0} for k, v in predictions.items()
]
else:
formatted_predictions = [{"id": k, "prediction_text": v} for k, v in predictions.items()]
formatted_predictions = [{"id": str(k), "prediction_text": v} for k, v in predictions.items()]
references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples]
references = [{"id": str(ex["id"]), "answers": ex[answer_column_name]} for ex in examples]
return EvalPrediction(predictions=formatted_predictions, label_ids=references)
metric = evaluate.load("squad_v2" if data_args.version_2_with_negative else "squad")