mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 18:22:34 +06:00
[test] replace capsys with the more refined CaptureStderr/CaptureStdout (#6422)
* replace capsys with the more refined CaptureStderr/CaptureStdout * Update examples/seq2seq/test_seq2seq_examples.py Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
parent
ac5bcf236e
commit
87b359439f
@ -15,7 +15,7 @@ from torch.utils.data import DataLoader
|
|||||||
|
|
||||||
import lightning_base
|
import lightning_base
|
||||||
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
||||||
from transformers.testing_utils import require_multigpu
|
from transformers.testing_utils import CaptureStderr, CaptureStdout, require_multigpu
|
||||||
|
|
||||||
from .distillation import distill_main, evaluate_checkpoint
|
from .distillation import distill_main, evaluate_checkpoint
|
||||||
from .finetune import SummarizationModule, main
|
from .finetune import SummarizationModule, main
|
||||||
@ -329,7 +329,7 @@ def test_finetune_extra_model_args():
|
|||||||
assert str(excinfo.value) == f"model config doesn't have a `{unsupported_param}` attribute"
|
assert str(excinfo.value) == f"model config doesn't have a `{unsupported_param}` attribute"
|
||||||
|
|
||||||
|
|
||||||
def test_finetune_lr_shedulers(capsys):
|
def test_finetune_lr_schedulers():
|
||||||
args_d: dict = CHEAP_ARGS.copy()
|
args_d: dict = CHEAP_ARGS.copy()
|
||||||
|
|
||||||
task = "summarization"
|
task = "summarization"
|
||||||
@ -361,23 +361,23 @@ def test_finetune_lr_shedulers(capsys):
|
|||||||
|
|
||||||
# --help test
|
# --help test
|
||||||
with pytest.raises(SystemExit) as excinfo:
|
with pytest.raises(SystemExit) as excinfo:
|
||||||
args = parser.parse_args(args)
|
with CaptureStdout() as cs:
|
||||||
|
args = parser.parse_args(args)
|
||||||
assert False, "--help is expected to sys.exit"
|
assert False, "--help is expected to sys.exit"
|
||||||
assert excinfo.type == SystemExit
|
assert excinfo.type == SystemExit
|
||||||
captured = capsys.readouterr()
|
|
||||||
expected = lightning_base.arg_to_scheduler_metavar
|
expected = lightning_base.arg_to_scheduler_metavar
|
||||||
assert expected in captured.out, "--help is expected to list the supported schedulers"
|
assert expected in cs.out, "--help is expected to list the supported schedulers"
|
||||||
|
|
||||||
# --lr_scheduler=non_existing_scheduler test
|
# --lr_scheduler=non_existing_scheduler test
|
||||||
unsupported_param = "non_existing_scheduler"
|
unsupported_param = "non_existing_scheduler"
|
||||||
args = {f"--lr_scheduler={unsupported_param}"}
|
args = {f"--lr_scheduler={unsupported_param}"}
|
||||||
with pytest.raises(SystemExit) as excinfo:
|
with pytest.raises(SystemExit) as excinfo:
|
||||||
args = parser.parse_args(args)
|
with CaptureStderr() as cs:
|
||||||
|
args = parser.parse_args(args)
|
||||||
assert False, "invalid argument is expected to sys.exit"
|
assert False, "invalid argument is expected to sys.exit"
|
||||||
assert excinfo.type == SystemExit
|
assert excinfo.type == SystemExit
|
||||||
captured = capsys.readouterr()
|
|
||||||
expected = f"invalid choice: '{unsupported_param}'"
|
expected = f"invalid choice: '{unsupported_param}'"
|
||||||
assert expected in captured.err, f"should have bailed on invalid choice of scheduler {unsupported_param}"
|
assert expected in cs.err, f"should have bailed on invalid choice of scheduler {unsupported_param}"
|
||||||
|
|
||||||
# --lr_scheduler=existing_scheduler test
|
# --lr_scheduler=existing_scheduler test
|
||||||
supported_param = "cosine"
|
supported_param = "cosine"
|
||||||
|
Loading…
Reference in New Issue
Block a user