mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 18:22:34 +06:00

Changelog (squashed commits): Add BERT Loses Patience (Patience-based Early Exit); update model archive; update format; sort imports; flake8; add results; full results; align the table; refactor to inherit; default per-GPU eval batch size = 1; formatting; isort; modify README; add check; fix format; docstrings; ALBERT & BERT for sequence classification no longer inherit from the original; remove incorrect comments; sync up with new code; add tests; finishing up.
49 lines
1.3 KiB
Python
import argparse
|
|
import logging
|
|
import sys
|
|
import unittest
|
|
from unittest.mock import patch
|
|
|
|
import run_glue_with_pabee
|
|
|
|
|
|
# Keep a module-level handle on the root logger so the test below can attach
# a stdout handler to it; run the whole suite at DEBUG verbosity.
logger = logging.getLogger()
logging.basicConfig(level=logging.DEBUG)
|
|
|
|
|
|
def get_setup_file():
    """Parse the ``-f`` option from the command line and return its value.

    Returns:
        str | None: the argument supplied after ``-f``, or ``None`` if the
        flag was not given.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("-f")
    return arg_parser.parse_args().f
|
|
|
|
|
|
class PabeeTests(unittest.TestCase):
    """Smoke test for the PABEE (Patience-based Early Exit) GLUE example.

    Runs ``run_glue_with_pabee.main()`` end-to-end on the small MRPC test
    fixture and checks that every reported metric clears a minimal bar.
    """

    def test_run_glue(self):
        # Mirror log output to stdout so training progress shows up in the
        # captured test output when the run fails.
        logger.addHandler(logging.StreamHandler(sys.stdout))

        # Command line for the example script; first entry stands in for the
        # program name in sys.argv.
        cli_args = [
            "run_glue_with_pabee.py",
            "--model_type", "albert",
            "--model_name_or_path", "albert-base-v2",
            "--data_dir", "./tests/fixtures/tests_samples/MRPC/",
            "--task_name", "mrpc",
            "--do_train",
            "--do_eval",
            "--output_dir", "./tests/fixtures/tests_samples/temp_dir",
            "--per_gpu_train_batch_size=2",
            "--per_gpu_eval_batch_size=1",
            "--learning_rate=2e-5",
            "--max_steps=50",
            "--warmup_steps=2",
            "--overwrite_output_dir",
            "--seed=42",
            "--max_seq_length=128",
        ]

        with patch.object(sys, "argv", cli_args):
            eval_metrics = run_glue_with_pabee.main()
            # 0.75 is a loose sanity threshold for a 50-step run on the
            # MRPC fixture, not a claim about converged accuracy.
            for metric in eval_metrics.values():
                self.assertGreaterEqual(metric, 0.75)
|