[Examples] Clean summarization and translation example testing files for T5 and Bart (#3514)

* fix conflicts

* add model size argument to summarization

* correct wrong import

* fix isort

* correct imports

* other isort and make style fixes

* make style
This commit is contained in:
Patrick von Platen 2020-03-31 17:54:13 +02:00 committed by GitHub
parent 0373b60c4c
commit ae6834e028
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 36 additions and 22 deletions

View File

@ -45,7 +45,7 @@ def generate_summaries(
fout.flush() fout.flush()
def _run_generate(): def run_generate():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
"source_path", type=str, help="like cnn_dm/test.source", "source_path", type=str, help="like cnn_dm/test.source",
@ -68,4 +68,4 @@ def _run_generate():
if __name__ == "__main__": if __name__ == "__main__":
_run_generate() run_generate()

View File

@ -1,16 +1,13 @@
import logging import logging
import os
import sys import sys
import tempfile import tempfile
import unittest import unittest
from pathlib import Path from pathlib import Path
from unittest.mock import patch from unittest.mock import patch
from .evaluate_cnn import _run_generate from .evaluate_cnn import run_generate
output_file_name = "output_bart_sum.txt"
articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."] articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
@ -26,8 +23,10 @@ class TestBartExamples(unittest.TestCase):
with tmp.open("w") as f: with tmp.open("w") as f:
f.write("\n".join(articles)) f.write("\n".join(articles))
testargs = ["evaluate_cnn.py", str(tmp), output_file_name, "sshleifer/bart-tiny-random"] output_file_name = Path(tempfile.gettempdir()) / "utest_output_bart_sum.hypo"
testargs = ["evaluate_cnn.py", str(tmp), str(output_file_name), "sshleifer/bart-tiny-random"]
with patch.object(sys, "argv", testargs): with patch.object(sys, "argv", testargs):
_run_generate() run_generate()
self.assertTrue(Path(output_file_name).exists()) self.assertTrue(Path(output_file_name).exists())
os.remove(Path(output_file_name))

View File

@ -64,7 +64,7 @@ def run_generate():
parser.add_argument( parser.add_argument(
"model_size", "model_size",
type=str, type=str,
help="T5 model size, either 't5-small', 't5-base' or 't5-large'. Defaults to base.", help="T5 model size, either 't5-small', 't5-base', 't5-large', 't5-3b', 't5-11b'. Defaults to 't5-base'.",
default="t5-base", default="t5-base",
) )
parser.add_argument( parser.add_argument(

View File

@ -1,5 +1,4 @@
import logging import logging
import os
import sys import sys
import tempfile import tempfile
import unittest import unittest
@ -26,10 +25,13 @@ class TestT5Examples(unittest.TestCase):
tmp = Path(tempfile.gettempdir()) / "utest_generations_t5_sum.hypo" tmp = Path(tempfile.gettempdir()) / "utest_generations_t5_sum.hypo"
with tmp.open("w") as f: with tmp.open("w") as f:
f.write("\n".join(articles)) f.write("\n".join(articles))
testargs = ["evaluate_cnn.py", "t5-small", str(tmp), output_file_name, str(tmp), score_file_name]
output_file_name = Path(tempfile.gettempdir()) / "utest_output_t5_sum.hypo"
score_file_name = Path(tempfile.gettempdir()) / "utest_score_t5_sum.hypo"
testargs = ["evaluate_cnn.py", "t5-small", str(tmp), str(output_file_name), str(tmp), str(score_file_name)]
with patch.object(sys, "argv", testargs): with patch.object(sys, "argv", testargs):
run_generate() run_generate()
self.assertTrue(Path(output_file_name).exists()) self.assertTrue(Path(output_file_name).exists())
self.assertTrue(Path(score_file_name).exists()) self.assertTrue(Path(score_file_name).exists())
os.remove(Path(output_file_name))
os.remove(Path(score_file_name))

View File

@ -14,13 +14,13 @@ def chunks(lst, n):
yield lst[i : i + n] yield lst[i : i + n]
def generate_translations(lns, output_file_path, batch_size, device): def generate_translations(lns, output_file_path, model_size, batch_size, device):
output_file = Path(output_file_path).open("w") output_file = Path(output_file_path).open("w")
model = T5ForConditionalGeneration.from_pretrained("t5-base") model = T5ForConditionalGeneration.from_pretrained(model_size)
model.to(device) model.to(device)
tokenizer = T5Tokenizer.from_pretrained("t5-base") tokenizer = T5Tokenizer.from_pretrained(model_size)
# update config with summarization specific params # update config with summarization specific params
task_specific_params = model.config.task_specific_params task_specific_params = model.config.task_specific_params
@ -52,6 +52,12 @@ def calculate_bleu_score(output_lns, refs_lns, score_path):
def run_generate(): def run_generate():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument(
"model_size",
type=str,
help="T5 model size, either 't5-small', 't5-base', 't5-large', 't5-3b', 't5-11b'. Defaults to 't5-base'.",
default="t5-base",
)
parser.add_argument( parser.add_argument(
"input_path", type=str, help="like wmt/newstest2013.en", "input_path", type=str, help="like wmt/newstest2013.en",
) )
@ -78,7 +84,7 @@ def run_generate():
input_lns = [x.strip().replace(dash_pattern[0], dash_pattern[1]) for x in open(args.input_path).readlines()] input_lns = [x.strip().replace(dash_pattern[0], dash_pattern[1]) for x in open(args.input_path).readlines()]
generate_translations(input_lns, args.output_path, args.batch_size, args.device) generate_translations(input_lns, args.output_path, args.model_size, args.batch_size, args.device)
output_lns = [x.strip() for x in open(args.output_path).readlines()] output_lns = [x.strip() for x in open(args.output_path).readlines()]
refs_lns = [x.strip().replace(dash_pattern[0], dash_pattern[1]) for x in open(args.reference_path).readlines()] refs_lns = [x.strip().replace(dash_pattern[0], dash_pattern[1]) for x in open(args.reference_path).readlines()]

View File

@ -1,5 +1,4 @@
import logging import logging
import os
import sys import sys
import tempfile import tempfile
import unittest import unittest
@ -33,11 +32,19 @@ class TestT5Examples(unittest.TestCase):
with tmp_target.open("w") as f: with tmp_target.open("w") as f:
f.write("\n".join(translation)) f.write("\n".join(translation))
testargs = ["evaluate_wmt.py", str(tmp_source), output_file_name, str(tmp_target), score_file_name] output_file_name = Path(tempfile.gettempdir()) / "utest_output_trans.hypo"
score_file_name = Path(tempfile.gettempdir()) / "utest_score.hypo"
testargs = [
"evaluate_wmt.py",
"t5-small",
str(tmp_source),
str(output_file_name),
str(tmp_target),
str(score_file_name),
]
with patch.object(sys, "argv", testargs): with patch.object(sys, "argv", testargs):
run_generate() run_generate()
self.assertTrue(Path(output_file_name).exists()) self.assertTrue(Path(output_file_name).exists())
self.assertTrue(Path(score_file_name).exists()) self.assertTrue(Path(score_file_name).exists())
os.remove(Path(output_file_name))
os.remove(Path(score_file_name))