mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[s2s examples] Replace -100 token ids with the tokenizer pad_id for compute_metrics (#10046)
* Replace -100 token ids with the tokenizer pad_token_id for compute_metrics * Fix typo for label_ids
This commit is contained in:
parent
c9df1b1d53
commit
ece6c51458
@ -82,8 +82,11 @@ def build_compute_metrics_fn(task_name: str, tokenizer: PreTrainedTokenizer) ->
|
||||
return np.count_nonzero(tokens != tokenizer.pad_token_id)
|
||||
|
||||
def decode_pred(pred: EvalPrediction) -> Tuple[List[str], List[str]]:
    """Decode model predictions and gold labels into stripped strings.

    Label sequences use -100 as the "ignore" sentinel for the loss; -100 is
    not a valid vocabulary index, so it must be mapped to the tokenizer's
    pad token id before decoding.

    Args:
        pred: an ``EvalPrediction`` with ``predictions`` and ``label_ids``
            arrays of token ids.

    Returns:
        A ``(pred_str, label_str)`` pair of lists of whitespace-stripped
        decoded strings, aligned index-for-index.
    """
    pred_ids = pred.predictions
    # Use np.where to build a sanitized copy instead of masked in-place
    # assignment, so the caller's pred.label_ids array is left unmodified.
    label_ids = np.where(pred.label_ids == -100, tokenizer.pad_token_id, pred.label_ids)
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
    pred_str = lmap(str.strip, pred_str)
    label_str = lmap(str.strip, label_str)
    return pred_str, label_str
|
||||
|
Loading…
Reference in New Issue
Block a user