From ff8dcb5efa6816cc3e838cbf9ebc49becfb83c95 Mon Sep 17 00:00:00 2001 From: Observer46 Date: Fri, 6 Jan 2023 13:19:42 +0100 Subject: [PATCH] Fix arguments passed to predict function in QA Seq2seq training script (#21026) fix args passed to predict function --- examples/pytorch/question-answering/trainer_seq2seq_qa.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/pytorch/question-answering/trainer_seq2seq_qa.py b/examples/pytorch/question-answering/trainer_seq2seq_qa.py index 73517c06d7c..6abb41b33fe 100644 --- a/examples/pytorch/question-answering/trainer_seq2seq_qa.py +++ b/examples/pytorch/question-answering/trainer_seq2seq_qa.py @@ -151,7 +151,7 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer): if self.post_process_function is None or self.compute_metrics is None: return output - predictions = self.post_process_function(predict_examples, predict_dataset, output.predictions, "predict") + predictions = self.post_process_function(predict_examples, predict_dataset, output, "predict") metrics = self.compute_metrics(predictions) # Prefix all keys with metric_key_prefix + '_'