import argparse import logging import os import sys import tempfile import unittest from pathlib import Path from unittest.mock import patch import torch from torch.utils.data import DataLoader from transformers import BartTokenizer from .distillation import distill_main, evaluate_checkpoint from .finetune import main from .run_eval import generate_summaries, run_generate from .utils import SummarizationDataset, lmap, pickle_load logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger() FP16_EVER = False CHEAP_ARGS = { "logger": "default", "num_workers": 2, "alpha_hid": 0, "freeze_embeds": True, "enc_only": False, "tgt_suffix": "", "resume_from_checkpoint": None, "sortish_sampler": True, "student_decoder_layers": 1, "val_check_interval": 1.0, "output_dir": "", "fp16": False, "no_teacher": False, "fp16_opt_level": "O1", "gpus": 1 if torch.cuda.is_available() else 0, "n_tpu_cores": 0, "max_grad_norm": 1.0, "do_train": True, "do_predict": True, "gradient_accumulation_steps": 1, "server_ip": "", "server_port": "", "seed": 42, "model_type": "bart", "model_name_or_path": "sshleifer/bart-tiny-random", "config_name": "", "tokenizer_name": "facebook/bart-large", "cache_dir": "", "do_lower_case": False, "learning_rate": 3e-05, "weight_decay": 0.0, "adam_epsilon": 1e-08, "warmup_steps": 0, "num_train_epochs": 1, "train_batch_size": 2, "eval_batch_size": 2, "max_source_length": 12, "max_target_length": 12, "val_max_target_length": 12, "test_max_target_length": 12, "fast_dev_run": False, "no_cache": False, "n_train": -1, "n_val": -1, "n_test": -1, "student_encoder_layers": 1, "alpha_loss_encoder": 0.0, "freeze_encoder": False, "auto_scale_batch_size": False, } def _dump_articles(path: Path, articles: list): with path.open("w") as f: f.write("\n".join(articles)) MSG = "T5 is broken at the moment" T5_TINY = "patrickvonplaten/t5-tiny-random" def make_test_data_dir(): tmp_dir = Path(tempfile.gettempdir()) articles = [" Sam ate lunch today", "Sams lunch ingredients"] summaries = ["A very interesting story about what I ate for lunch.", "Avocado, celery, turkey, coffee"] for split in ["train", "val", "test"]: _dump_articles((tmp_dir / f"{split}.source"), articles) _dump_articles((tmp_dir / f"{split}.target"), summaries) return tmp_dir class TestSummarizationDistiller(unittest.TestCase): @classmethod def setUpClass(cls): logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks return cls @unittest.skipUnless(torch.cuda.device_count() > 1, "skipping multiGPU test") def test_bdc_multigpu(self): updates = dict( student_encoder_layers=2, student_decoder_layers=1, no_teacher=True, freeze_encoder=True, gpus=2, sortish_sampler=False, fp16_opt_level="O1", fp16=FP16_EVER, ) self._bart_distiller_cli(updates) def test_bdc_t5_train(self): updates = dict( fp16=FP16_EVER, gpus=1 if torch.cuda.is_available() else 0, model_type="t5", model_name_or_path=T5_TINY, do_train=True, do_predict=True, tokenizer_name=T5_TINY, no_teacher=True, alpha_hid=2.0, ) self._bart_distiller_cli(updates) def test_bdc_no_teacher(self): updates = dict(student_encoder_layers=2, student_decoder_layers=1, no_teacher=True,) self._bart_distiller_cli(updates) def test_bdc_yes_teacher(self): updates = dict(student_encoder_layers=2, student_decoder_layers=1,) self._bart_distiller_cli(updates) def test_bdc_checkpointing(self): updates = dict( student_encoder_layers=2, student_decoder_layers=1, num_train_epochs=4, val_check_interval=0.25, alpha_hid=2.0, ) model = self._bart_distiller_cli(updates, check_contents=False) ckpts = list(Path(model.output_dir).glob("*.ckpt")) self.assertEqual(1, len(ckpts)) transformer_ckpts = list(Path(model.output_dir).glob("**/*.bin")) self.assertEqual(len(transformer_ckpts), len(ckpts)) new_transformer_ckpts = list(Path(model.output_dir).glob("**/*.bin")) self.assertEqual(len(new_transformer_ckpts), 1) examples = lmap(str.strip, model.hparams.data_dir.joinpath("test.source").open().readlines()) out_path = tempfile.mktemp() generate_summaries(examples, out_path, new_transformer_ckpts[0].parent) self.assertTrue(Path(out_path).exists()) evaluate_checkpoint(ckpts[0], dest_dir=Path(tempfile.mkdtemp())) def _bart_distiller_cli(self, updates, check_contents=True): default_updates = dict( train_batch_size=1, eval_batch_size=2, num_train_epochs=2, alpha_mlm=0.2, alpha_ce=0.8, do_predict=True, gpus=1 if torch.cuda.is_available() else 0, model_name_or_path="sshleifer/tinier_bart", teacher=CHEAP_ARGS["model_name_or_path"], val_check_interval=0.5, alpha_encoder_loss=0.4, ) default_updates.update(updates) args_d: dict = CHEAP_ARGS.copy() tmp_dir = make_test_data_dir() output_dir = tempfile.mkdtemp(prefix="output_") args_d.update(data_dir=tmp_dir, output_dir=output_dir, **default_updates) model = distill_main(argparse.Namespace(**args_d)) if not check_contents: return model contents = os.listdir(output_dir) ckpt_name = "val_avg_rouge2=0.0000-step_count=2.ckpt" # "val_avg_rouge2=0.0000-epoch=1.ckpt" # "epoch=1-val_avg_rouge2=0.0000.ckpt" contents = {os.path.basename(p) for p in contents} self.assertIn(ckpt_name, contents) self.assertIn("metrics.pkl", contents) self.assertIn("test_generations.txt", contents) self.assertIn("val_generations_00001.txt", contents) self.assertIn("val_results_00001.txt", contents) self.assertIn("test_results.txt", contents) metrics = pickle_load(Path(output_dir) / "metrics.pkl") desired_n_evals = int(args_d["num_train_epochs"] * (1 / args_d["val_check_interval"]) + 1) self.assertEqual(len(metrics["val"]), desired_n_evals) self.assertEqual(len(metrics["train"]), 0) # doesn't get logged here return model class TestBartExamples(unittest.TestCase): @classmethod def setUpClass(cls): stream_handler = logging.StreamHandler(sys.stdout) logger.addHandler(stream_handler) logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks return cls def test_bart_cnn_cli(self): tmp = Path(tempfile.gettempdir()) / "utest_generations_bart_sum.hypo" output_file_name = Path(tempfile.gettempdir()) / "utest_output_bart_sum.hypo" articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."] _dump_articles(tmp, articles) testargs = ["run_eval.py", str(tmp), str(output_file_name), "sshleifer/bart-tiny-random"] with patch.object(sys, "argv", testargs): run_generate() self.assertTrue(Path(output_file_name).exists()) os.remove(Path(output_file_name)) def test_t5_run_sum_cli(self): args_d: dict = CHEAP_ARGS.copy() tmp_dir = make_test_data_dir() output_dir = tempfile.mkdtemp(prefix="output_") args_d.update( data_dir=tmp_dir, model_name_or_path=T5_TINY, tokenizer_name=None, # T5_TINY, train_batch_size=2, eval_batch_size=2, gpus=0, output_dir=output_dir, do_predict=True, ) assert "n_train" in args_d args = argparse.Namespace(**args_d) main(args) def test_bart_summarization_dataset(self): tmp_dir = Path(tempfile.gettempdir()) articles = [" Sam ate lunch today", "Sams lunch ingredients"] summaries = ["A very interesting story about what I ate for lunch.", "Avocado, celery, turkey, coffee"] _dump_articles((tmp_dir / "train.source"), articles) _dump_articles((tmp_dir / "train.target"), summaries) tokenizer = BartTokenizer.from_pretrained("facebook/bart-large") max_len_source = max(len(tokenizer.encode(a)) for a in articles) max_len_target = max(len(tokenizer.encode(a)) for a in summaries) trunc_target = 4 train_dataset = SummarizationDataset( tokenizer, data_dir=tmp_dir, type_path="train", max_source_length=20, max_target_length=trunc_target, ) dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn) for batch in dataloader: self.assertEqual(batch["attention_mask"].shape, batch["input_ids"].shape) # show that articles were trimmed. self.assertEqual(batch["input_ids"].shape[1], max_len_source) self.assertGreater(20, batch["input_ids"].shape[1]) # trimmed significantly # show that targets were truncated self.assertEqual(batch["decoder_input_ids"].shape[1], trunc_target) # Truncated self.assertGreater(max_len_target, trunc_target) # Truncated def list_to_text_file(lst, path): dest = Path(path) dest.open("w+").writelines(lst)