mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[cleanup] examples test_run_squad uses tiny model (#5059)
This commit is contained in:
parent
439aa1d6e9
commit
c3e607496c
@ -55,7 +55,7 @@ class ExamplesTests(unittest.TestCase):
|
||||
|
||||
testargs = """
|
||||
run_glue.py
|
||||
--model_name_or_path bert-base-uncased
|
||||
--model_name_or_path distilbert-base-uncased
|
||||
--data_dir ./tests/fixtures/tests_samples/MRPC/
|
||||
--task_name mrpc
|
||||
--do_train
|
||||
@ -79,6 +79,7 @@ class ExamplesTests(unittest.TestCase):
|
||||
def test_run_language_modeling(self):
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
# TODO: switch to smaller model like sshleifer/tiny-distilroberta-base
|
||||
|
||||
testargs = """
|
||||
run_language_modeling.py
|
||||
@ -105,10 +106,9 @@ class ExamplesTests(unittest.TestCase):
|
||||
|
||||
testargs = """
|
||||
run_squad.py
|
||||
--model_type=bert
|
||||
--model_name_or_path=bert-base-uncased
|
||||
--model_type=distilbert
|
||||
--model_name_or_path=sshleifer/tiny-distilbert-base-cased-distilled-squad
|
||||
--data_dir=./tests/fixtures/tests_samples/SQUAD
|
||||
--model_name=bert-base-uncased
|
||||
--output_dir=./tests/fixtures/tests_samples/temp_dir
|
||||
--max_steps=10
|
||||
--warmup_steps=2
|
||||
@ -123,15 +123,15 @@ class ExamplesTests(unittest.TestCase):
|
||||
""".split()
|
||||
with patch.object(sys, "argv", testargs):
|
||||
result = run_squad.main()
|
||||
self.assertGreaterEqual(result["f1"], 30)
|
||||
self.assertGreaterEqual(result["exact"], 30)
|
||||
self.assertGreaterEqual(result["f1"], 25)
|
||||
self.assertGreaterEqual(result["exact"], 21)
|
||||
|
||||
def test_generation(self):
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
testargs = ["run_generation.py", "--prompt=Hello", "--length=10", "--seed=42"]
|
||||
model_type, model_name = ("--model_type=openai-gpt", "--model_name_or_path=openai-gpt")
|
||||
model_type, model_name = ("--model_type=gpt2", "--model_name_or_path=sshleifer/tiny-gpt2")
|
||||
with patch.object(sys, "argv", testargs + [model_type, model_name]):
|
||||
result = run_generation.main()
|
||||
self.assertGreaterEqual(len(result[0]), 10)
|
||||
|
Loading…
Reference in New Issue
Block a user