diff --git a/examples/pytorch/question-answering/trainer_qa.py b/examples/pytorch/question-answering/trainer_qa.py index 702d8ac6abb..7f98eba236c 100644 --- a/examples/pytorch/question-answering/trainer_qa.py +++ b/examples/pytorch/question-answering/trainer_qa.py @@ -31,7 +31,7 @@ class QuestionAnsweringTrainer(Trainer): self.eval_examples = eval_examples self.post_process_function = post_process_function - def evaluate(self, eval_dataset=None, eval_examples=None, ignore_keys=None): + def evaluate(self, eval_dataset=None, eval_examples=None, ignore_keys=None, metric_key_prefix: str = "eval"): eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset eval_dataloader = self.get_eval_dataloader(eval_dataset) eval_examples = self.eval_examples if eval_examples is None else eval_examples @@ -56,6 +56,11 @@ class QuestionAnsweringTrainer(Trainer): eval_preds = self.post_process_function(eval_examples, eval_dataset, output.predictions) metrics = self.compute_metrics(eval_preds) + # Prefix all keys with metric_key_prefix + '_' + for key in list(metrics.keys()): + if not key.startswith(f"{metric_key_prefix}_"): + metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key) + self.log(metrics) else: metrics = {} @@ -67,7 +72,7 @@ class QuestionAnsweringTrainer(Trainer): self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics) return metrics - def predict(self, predict_dataset, predict_examples, ignore_keys=None): + def predict(self, predict_dataset, predict_examples, ignore_keys=None, metric_key_prefix: str = "test"): predict_dataloader = self.get_test_dataloader(predict_dataset) # Temporarily disable metric computation, we will do it in the loop here. @@ -92,4 +97,9 @@ class QuestionAnsweringTrainer(Trainer): predictions = self.post_process_function(predict_examples, predict_dataset, output.predictions, "predict") metrics = self.compute_metrics(predictions) + # Prefix all keys with metric_key_prefix + '_' + for key in list(metrics.keys()): + if not key.startswith(f"{metric_key_prefix}_"): + metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key) + return PredictionOutput(predictions=predictions.predictions, label_ids=predictions.label_ids, metrics=metrics) diff --git a/examples/pytorch/test_examples.py b/examples/pytorch/test_examples.py index 717bca47c67..74f1cb28c1e 100644 --- a/examples/pytorch/test_examples.py +++ b/examples/pytorch/test_examples.py @@ -213,7 +213,7 @@ class ExamplesTests(TestCasePlus): tmp_dir = self.get_auto_remove_tmp_dir() testargs = f""" - run_squad.py + run_qa.py --model_name_or_path bert-base-uncased --version_2_with_negative --train_file tests/fixtures/tests_samples/SQUAD/sample.json @@ -232,8 +232,8 @@ class ExamplesTests(TestCasePlus): with patch.object(sys, "argv", testargs): run_squad.main() result = get_results(tmp_dir) - self.assertGreaterEqual(result["f1"], 30) - self.assertGreaterEqual(result["exact"], 30) + self.assertGreaterEqual(result["eval_f1"], 30) + self.assertGreaterEqual(result["eval_exact"], 30) def test_run_swag(self): stream_handler = logging.StreamHandler(sys.stdout)