Fix arguments passed to predict function in QA Seq2seq training script (#21026)

fix args passed to predict function
This commit is contained in:
Observer46 2023-01-06 13:19:42 +01:00 committed by GitHub
parent 35a7052b61
commit ff8dcb5efa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -151,7 +151,7 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
if self.post_process_function is None or self.compute_metrics is None: if self.post_process_function is None or self.compute_metrics is None:
return output return output
predictions = self.post_process_function(predict_examples, predict_dataset, output.predictions, "predict") predictions = self.post_process_function(predict_examples, predict_dataset, output, "predict")
metrics = self.compute_metrics(predictions) metrics = self.compute_metrics(predictions)
# Prefix all keys with metric_key_prefix + '_' # Prefix all keys with metric_key_prefix + '_'