From ae6834e028ecdf7fdbe886c1f86d0e02d5fef6f0 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 31 Mar 2020 17:54:13 +0200 Subject: [PATCH] [Examples] Clean summarization and translation example testing files for T5 and Bart (#3514) * fix conflicts * add model size argument to summarization * correct wrong import * fix isort * correct imports * other isort make style * make style --- examples/summarization/bart/evaluate_cnn.py | 4 ++-- examples/summarization/bart/test_bart_examples.py | 13 ++++++------- examples/summarization/t5/evaluate_cnn.py | 2 +- examples/summarization/t5/test_t5_examples.py | 10 ++++++---- examples/translation/t5/evaluate_wmt.py | 14 ++++++++++---- examples/translation/t5/test_t5_examples.py | 15 +++++++++++---- 6 files changed, 36 insertions(+), 22 deletions(-) diff --git a/examples/summarization/bart/evaluate_cnn.py b/examples/summarization/bart/evaluate_cnn.py index 0903e0c0f98..fe682257e31 100644 --- a/examples/summarization/bart/evaluate_cnn.py +++ b/examples/summarization/bart/evaluate_cnn.py @@ -45,7 +45,7 @@ def generate_summaries( fout.flush() -def _run_generate(): +def run_generate(): parser = argparse.ArgumentParser() parser.add_argument( "source_path", type=str, help="like cnn_dm/test.source", @@ -68,4 +68,4 @@ def _run_generate(): if __name__ == "__main__": - _run_generate() + run_generate() diff --git a/examples/summarization/bart/test_bart_examples.py b/examples/summarization/bart/test_bart_examples.py index b1d1d8e756b..40be3b5668d 100644 --- a/examples/summarization/bart/test_bart_examples.py +++ b/examples/summarization/bart/test_bart_examples.py @@ -1,16 +1,13 @@ import logging -import os import sys import tempfile import unittest from pathlib import Path from unittest.mock import patch -from .evaluate_cnn import _run_generate +from .evaluate_cnn import run_generate -output_file_name = "output_bart_sum.txt" - articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."] logging.basicConfig(level=logging.DEBUG) @@ -26,8 +23,10 @@ class TestBartExamples(unittest.TestCase): with tmp.open("w") as f: f.write("\n".join(articles)) - testargs = ["evaluate_cnn.py", str(tmp), output_file_name, "sshleifer/bart-tiny-random"] + output_file_name = Path(tempfile.gettempdir()) / "utest_output_bart_sum.hypo" + + testargs = ["evaluate_cnn.py", str(tmp), str(output_file_name), "sshleifer/bart-tiny-random"] + with patch.object(sys, "argv", testargs): - _run_generate() + run_generate() self.assertTrue(Path(output_file_name).exists()) - os.remove(Path(output_file_name)) diff --git a/examples/summarization/t5/evaluate_cnn.py b/examples/summarization/t5/evaluate_cnn.py index 535c11093b6..3c923a46d7a 100644 --- a/examples/summarization/t5/evaluate_cnn.py +++ b/examples/summarization/t5/evaluate_cnn.py @@ -64,7 +64,7 @@ def run_generate(): parser.add_argument( "model_size", type=str, - help="T5 model size, either 't5-small', 't5-base' or 't5-large'. Defaults to base.", + help="T5 model size, either 't5-small', 't5-base', 't5-large', 't5-3b', 't5-11b'. Defaults to 't5-base'.", default="t5-base", ) parser.add_argument( diff --git a/examples/summarization/t5/test_t5_examples.py b/examples/summarization/t5/test_t5_examples.py index 57f3e342d77..cf26ae88920 100644 --- a/examples/summarization/t5/test_t5_examples.py +++ b/examples/summarization/t5/test_t5_examples.py @@ -1,5 +1,4 @@ import logging -import os import sys import tempfile import unittest @@ -26,10 +25,13 @@ class TestT5Examples(unittest.TestCase): tmp = Path(tempfile.gettempdir()) / "utest_generations_t5_sum.hypo" with tmp.open("w") as f: f.write("\n".join(articles)) - testargs = ["evaluate_cnn.py", "t5-small", str(tmp), output_file_name, str(tmp), score_file_name] + + output_file_name = Path(tempfile.gettempdir()) / "utest_output_t5_sum.hypo" + score_file_name = Path(tempfile.gettempdir()) / "utest_score_t5_sum.hypo" + + testargs = ["evaluate_cnn.py", "t5-small", str(tmp), str(output_file_name), str(tmp), str(score_file_name)] + with patch.object(sys, "argv", testargs): run_generate() self.assertTrue(Path(output_file_name).exists()) self.assertTrue(Path(score_file_name).exists()) - os.remove(Path(output_file_name)) - os.remove(Path(score_file_name)) diff --git a/examples/translation/t5/evaluate_wmt.py b/examples/translation/t5/evaluate_wmt.py index 307065d0a99..533811271b7 100644 --- a/examples/translation/t5/evaluate_wmt.py +++ b/examples/translation/t5/evaluate_wmt.py @@ -14,13 +14,13 @@ def chunks(lst, n): yield lst[i : i + n] -def generate_translations(lns, output_file_path, batch_size, device): +def generate_translations(lns, output_file_path, model_size, batch_size, device): output_file = Path(output_file_path).open("w") - model = T5ForConditionalGeneration.from_pretrained("t5-base") + model = T5ForConditionalGeneration.from_pretrained(model_size) model.to(device) - tokenizer = T5Tokenizer.from_pretrained("t5-base") + tokenizer = T5Tokenizer.from_pretrained(model_size) # update config with summarization specific params task_specific_params = model.config.task_specific_params @@ -52,6 +52,12 @@ def calculate_bleu_score(output_lns, refs_lns, score_path): def run_generate(): parser = argparse.ArgumentParser() + parser.add_argument( + "model_size", + type=str, + help="T5 model size, either 't5-small', 't5-base', 't5-large', 't5-3b', 't5-11b'. Defaults to 't5-base'.", + default="t5-base", + ) parser.add_argument( "input_path", type=str, help="like wmt/newstest2013.en", ) @@ -78,7 +84,7 @@ def run_generate(): input_lns = [x.strip().replace(dash_pattern[0], dash_pattern[1]) for x in open(args.input_path).readlines()] - generate_translations(input_lns, args.output_path, args.batch_size, args.device) + generate_translations(input_lns, args.output_path, args.model_size, args.batch_size, args.device) output_lns = [x.strip() for x in open(args.output_path).readlines()] refs_lns = [x.strip().replace(dash_pattern[0], dash_pattern[1]) for x in open(args.reference_path).readlines()] diff --git a/examples/translation/t5/test_t5_examples.py b/examples/translation/t5/test_t5_examples.py index eea17c227a1..6f548e5eb2b 100644 --- a/examples/translation/t5/test_t5_examples.py +++ b/examples/translation/t5/test_t5_examples.py @@ -1,5 +1,4 @@ import logging -import os import sys import tempfile import unittest @@ -33,11 +32,19 @@ class TestT5Examples(unittest.TestCase): with tmp_target.open("w") as f: f.write("\n".join(translation)) - testargs = ["evaluate_wmt.py", str(tmp_source), output_file_name, str(tmp_target), score_file_name] + output_file_name = Path(tempfile.gettempdir()) / "utest_output_trans.hypo" + score_file_name = Path(tempfile.gettempdir()) / "utest_score.hypo" + + testargs = [ + "evaluate_wmt.py", + "t5-small", + str(tmp_source), + str(output_file_name), + str(tmp_target), + str(score_file_name), + ] with patch.object(sys, "argv", testargs): run_generate() self.assertTrue(Path(output_file_name).exists()) self.assertTrue(Path(score_file_name).exists()) - os.remove(Path(output_file_name)) - os.remove(Path(score_file_name))