fixed JSON error in run_qa with fp16 (#9186)

This commit is contained in:
Wissam Antoun 2020-12-18 14:53:23 +02:00 committed by GitHub
parent 66a14a2f6f
commit fd7b6a5274
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -206,7 +206,7 @@ def postprocess_qa_predictions(
# Make `predictions` JSON-serializable by casting np.float back to float.
all_nbest_json[example["id"]] = [
{k: (float(v) if isinstance(v, (np.float32, np.float64)) else v) for k, v in pred.items()}
{k: (float(v) if isinstance(v, (np.float16, np.float32, np.float64)) else v) for k, v in pred.items()}
for pred in predictions
]
@ -394,7 +394,7 @@ def postprocess_qa_predictions_with_beam_search(
# Make `predictions` JSON-serializable by casting np.float back to float.
all_nbest_json[example["id"]] = [
{k: (float(v) if isinstance(v, (np.float32, np.float64)) else v) for k, v in pred.items()}
{k: (float(v) if isinstance(v, (np.float16, np.float32, np.float64)) else v) for k, v in pred.items()}
for pred in predictions
]