[cleanup] examples test_run_squad uses tiny model (#5059)

This commit is contained in:
Sam Shleifer 2020-06-16 14:06:45 -04:00 committed by GitHub
parent 439aa1d6e9
commit c3e607496c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -55,7 +55,7 @@ class ExamplesTests(unittest.TestCase):
testargs = """
run_glue.py
--model_name_or_path bert-base-uncased
--model_name_or_path distilbert-base-uncased
--data_dir ./tests/fixtures/tests_samples/MRPC/
--task_name mrpc
--do_train
@ -79,6 +79,7 @@ class ExamplesTests(unittest.TestCase):
def test_run_language_modeling(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
# TODO: switch to smaller model like sshleifer/tiny-distilroberta-base
testargs = """
run_language_modeling.py
@ -105,10 +106,9 @@ class ExamplesTests(unittest.TestCase):
testargs = """
run_squad.py
--model_type=bert
--model_name_or_path=bert-base-uncased
--model_type=distilbert
--model_name_or_path=sshleifer/tiny-distilbert-base-cased-distilled-squad
--data_dir=./tests/fixtures/tests_samples/SQUAD
--model_name=bert-base-uncased
--output_dir=./tests/fixtures/tests_samples/temp_dir
--max_steps=10
--warmup_steps=2
@ -123,15 +123,15 @@ class ExamplesTests(unittest.TestCase):
""".split()
with patch.object(sys, "argv", testargs):
result = run_squad.main()
self.assertGreaterEqual(result["f1"], 30)
self.assertGreaterEqual(result["exact"], 30)
self.assertGreaterEqual(result["f1"], 25)
self.assertGreaterEqual(result["exact"], 21)
def test_generation(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
testargs = ["run_generation.py", "--prompt=Hello", "--length=10", "--seed=42"]
model_type, model_name = ("--model_type=openai-gpt", "--model_name_or_path=openai-gpt")
model_type, model_name = ("--model_type=gpt2", "--model_name_or_path=sshleifer/tiny-gpt2")
with patch.object(sys, "argv", testargs + [model_type, model_name]):
result = run_generation.main()
self.assertGreaterEqual(len(result[0]), 10)