Replace -100s in predictions by the pad token (#22693)

* Replace -100s in predictions by the pad token

* Style

* Try to catch them all
Sylvain Gugger 2023-04-11 09:32:20 -04:00 committed by GitHub
parent ff73deeb0e
commit 1b1867d86b
3 changed files with 15 additions and 8 deletions
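
The change is the same in all three example scripts: with predict_with_generate, the predictions gathered during evaluation can come back padded with -100 (the same sentinel used to mask label positions in the loss), and -100 is not a valid token id, so it must be mapped back to the tokenizer's pad token before decoding. A minimal standalone sketch of the pattern; the array values and pad id are illustrative (in the scripts the pad id comes from the loaded tokenizer):

import numpy as np

# Stand-in for outputs.predictions / predict_results.predictions;
# pad_token_id = 0 is purely illustrative.
pad_token_id = 0
preds = np.array([[17, 42, 5, -100, -100],
                  [8, 3, 9, 11, -100]])

# -100 is not a valid vocabulary index, so map it back to the pad
# token before tokenizer.batch_decode is called.
preds = np.where(preds != -100, preds, pad_token_id)
print(preds.tolist())  # [[17, 42, 5, 0, 0], [8, 3, 9, 11, 0]]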

View File

@@ -26,6 +26,7 @@ from typing import List, Optional, Tuple
 import datasets
 import evaluate
+import numpy as np
 from datasets import load_dataset
 from trainer_seq2seq_qa import QuestionAnsweringSeq2SeqTrainer
@@ -614,6 +615,8 @@ def main():
         preds = outputs.predictions
         if isinstance(preds, tuple):
             preds = preds[0]
+        # Replace -100s used for padding as we can't decode them
+        preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
         decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
         # Build a map example to its corresponding features.

View File

@@ -632,10 +632,10 @@ def main():
         preds, labels = eval_preds
         if isinstance(preds, tuple):
             preds = preds[0]
+        # Replace -100s used for padding as we can't decode them
+        preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
         decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
-        if data_args.ignore_pad_token_for_loss:
-            # Replace -100 in the labels as we can't decode them.
-            labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
+        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
         decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

         # Some simple post-processing
@@ -714,8 +714,10 @@ def main():
         if trainer.is_world_process_zero():
             if training_args.predict_with_generate:
+                predictions = predict_results.predictions
+                predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
                 predictions = tokenizer.batch_decode(
-                    predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
+                    predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
                 )
                 predictions = [pred.strip() for pred in predictions]
                 output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt")
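
Assembled, the patched compute_metrics reads as below. This is a sketch following the diff, with tokenizer and metric assumed to be in scope (in the scripts, metric is an evaluate metric such as ROUGE or sacreBLEU, and the final compute call here is illustrative):

import numpy as np

def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    # Replace -100s used for padding as we can't decode them
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    # Labels are now cleaned unconditionally, not only when
    # data_args.ignore_pad_token_for_loss is set.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    return metric.compute(predictions=decoded_preds, references=decoded_labels)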

View File

@@ -543,10 +543,10 @@ def main():
         preds, labels = eval_preds
         if isinstance(preds, tuple):
             preds = preds[0]
+        # Replace -100s used for padding as we can't decode them
+        preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
         decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
-        if data_args.ignore_pad_token_for_loss:
-            # Replace -100 in the labels as we can't decode them.
-            labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
+        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
         decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

         # Some simple post-processing
@@ -626,8 +626,10 @@ def main():
         if trainer.is_world_process_zero():
             if training_args.predict_with_generate:
+                predictions = predict_results.predictions
+                predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
                 predictions = tokenizer.batch_decode(
-                    predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
+                    predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
                 )
                 predictions = [pred.strip() for pred in predictions]
                 output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt")
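
The predict-time hunks apply the same replacement before the decoded text is written out. A sketch of the patched block, assuming trainer, tokenizer, predict_results, and training_args as in the scripts; the final write mirrors what the scripts do with generated_predictions.txt:

import os
import numpy as np

if trainer.is_world_process_zero():
    if training_args.predict_with_generate:
        predictions = predict_results.predictions
        # Same fix as in compute_metrics: -100 padding -> pad token id.
        predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
        predictions = tokenizer.batch_decode(
            predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        predictions = [pred.strip() for pred in predictions]
        output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt")
        with open(output_prediction_file, "w", encoding="utf-8") as writer:
            writer.write("\n".join(predictions))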