mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
fixed JSON error in run_qa with fp16 (#9186)
This commit is contained in:
parent
66a14a2f6f
commit
fd7b6a5274
@ -206,7 +206,7 @@ def postprocess_qa_predictions(
|
||||
|
||||
# Make `predictions` JSON-serializable by casting np.float back to float.
|
||||
all_nbest_json[example["id"]] = [
|
||||
{k: (float(v) if isinstance(v, (np.float32, np.float64)) else v) for k, v in pred.items()}
|
||||
{k: (float(v) if isinstance(v, (np.float16, np.float32, np.float64)) else v) for k, v in pred.items()}
|
||||
for pred in predictions
|
||||
]
|
||||
|
||||
@ -394,7 +394,7 @@ def postprocess_qa_predictions_with_beam_search(
|
||||
|
||||
# Make `predictions` JSON-serializable by casting np.float back to float.
|
||||
all_nbest_json[example["id"]] = [
|
||||
{k: (float(v) if isinstance(v, (np.float32, np.float64)) else v) for k, v in pred.items()}
|
||||
{k: (float(v) if isinstance(v, (np.float16, np.float32, np.float64)) else v) for k, v in pred.items()}
|
||||
for pred in predictions
|
||||
]
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user