import logging
import sys
import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch

from .evaluate_cnn import run_generate


# Single short input article fed to the CLI; the content only needs to be non-empty text.
articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]

logging.basicConfig(level=logging.DEBUG)

# Root logger (getLogger() with no name); the test attaches a stdout handler to it.
logger = logging.getLogger()


class TestBartExamples(unittest.TestCase):
    """End-to-end smoke tests for the ``evaluate_cnn`` command-line entry point."""

    def test_bart_cnn_cli(self):
        """Run ``run_generate`` on one article via a patched ``sys.argv`` and check an output file is written."""
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)
        # Detach the handler afterwards so repeated runs don't stack duplicate stdout handlers
        # on the shared root logger.
        self.addCleanup(logger.removeHandler, stream_handler)

        # A private directory keeps concurrent runs from clobbering each other's files and
        # ensures nothing is left behind in the shared system temp dir.
        with tempfile.TemporaryDirectory() as tmp_dir_name:
            tmp_dir = Path(tmp_dir_name)

            input_file_name = tmp_dir / "utest_generations_bart_sum.hypo"
            input_file_name.write_text("\n".join(articles), encoding="utf-8")

            output_file_name = tmp_dir / "utest_output_bart_sum.hypo"
            testargs = ["evaluate_cnn.py", str(input_file_name), str(output_file_name), "sshleifer/bart-tiny-random"]
            with patch.object(sys, "argv", testargs):
                run_generate()

            # Must be checked before the temporary directory is removed.
            self.assertTrue(output_file_name.exists())