From 936b57158ad2641390422274fed6ee6c2a685e15 Mon Sep 17 00:00:00 2001
From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Date: Mon, 17 May 2021 10:10:13 -0400
Subject: [PATCH] Use new evaluation loop in TrainerQA (#11746)

---
 examples/pytorch/question-answering/trainer_qa.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/examples/pytorch/question-answering/trainer_qa.py b/examples/pytorch/question-answering/trainer_qa.py
index 36e2e544a7a..702d8ac6abb 100644
--- a/examples/pytorch/question-answering/trainer_qa.py
+++ b/examples/pytorch/question-answering/trainer_qa.py
@@ -39,8 +39,9 @@ class QuestionAnsweringTrainer(Trainer):
         # Temporarily disable metric computation, we will do it in the loop here.
         compute_metrics = self.compute_metrics
         self.compute_metrics = None
+        eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
         try:
-            output = self.prediction_loop(
+            output = eval_loop(
                 eval_dataloader,
                 description="Evaluation",
                 # No point gathering the predictions if there are no metrics, otherwise we defer to
@@ -72,8 +73,9 @@ class QuestionAnsweringTrainer(Trainer):
         # Temporarily disable metric computation, we will do it in the loop here.
         compute_metrics = self.compute_metrics
         self.compute_metrics = None
+        eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
         try:
-            output = self.prediction_loop(
+            output = eval_loop(
                 predict_dataloader,
                 description="Prediction",
                 # No point gathering the predictions if there are no metrics, otherwise we defer to
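
For context: the patch replaces the hard-coded calls to `self.prediction_loop` with a local `eval_loop` selected from `self.args.use_legacy_prediction_loop`, so `QuestionAnsweringTrainer` uses the new `evaluation_loop` from the base `Trainer` by default while keeping the legacy loop available as an opt-out. Below is a minimal standalone sketch of that dispatch pattern, assuming nothing beyond the flag shown in the diff; `LegacyArgs`, `run_legacy_loop`, and `run_new_loop` are hypothetical stand-ins for illustration, not transformers APIs.

    # Sketch of the select-once, call-through-one-name pattern from the patch.
    # All names here are hypothetical stand-ins, not part of transformers.
    from dataclasses import dataclass


    @dataclass
    class LegacyArgs:
        # Mirrors the TrainingArguments flag the patch reads.
        use_legacy_prediction_loop: bool = False


    def run_legacy_loop(description: str) -> str:
        return f"legacy loop ran: {description}"


    def run_new_loop(description: str) -> str:
        return f"new loop ran: {description}"


    def evaluate(args: LegacyArgs, description: str) -> str:
        # Same shape as the patch: pick the loop once based on the flag;
        # the rest of the method is agnostic to which loop was chosen.
        eval_loop = run_legacy_loop if args.use_legacy_prediction_loop else run_new_loop
        return eval_loop(description)


    if __name__ == "__main__":
        print(evaluate(LegacyArgs(), "Evaluation"))                                  # new loop
        print(evaluate(LegacyArgs(use_legacy_prediction_loop=True), "Evaluation"))   # legacy loop

Binding the chosen loop to a single local name means both call sites in the diff (evaluate and predict) change only one line each, and any future loop variant needs only the selection expression updated.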