[s2s examples] Replace -100 token ids with the tokenizer pad_id for compute_metrics (#10046)

* replace -100 token ids with the tokenizer pad_id for compute_metrics

* fixed typo for label_ids
This commit is contained in:
Olivier 2021-02-08 10:08:16 -05:00 committed by GitHub
parent c9df1b1d53
commit ece6c51458
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -82,8 +82,11 @@ def build_compute_metrics_fn(task_name: str, tokenizer: PreTrainedTokenizer) ->
return np.count_nonzero(tokens != tokenizer.pad_token_id)
def decode_pred(pred: EvalPrediction) -> Tuple[List[str], List[str]]:
pred_str = tokenizer.batch_decode(pred.predictions, skip_special_tokens=True)
label_str = tokenizer.batch_decode(pred.label_ids, skip_special_tokens=True)
pred_ids = pred.predictions
label_ids = pred.label_ids
pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
label_ids[label_ids == -100] = tokenizer.pad_token_id
label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
pred_str = lmap(str.strip, pred_str)
label_str = lmap(str.strip, label_str)
return pred_str, label_str