# transformers/examples/summarization/test_summarization_examples.py
import argparse
import logging
import os
import sys
import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch

import torch
from torch.utils.data import DataLoader

from transformers import BartTokenizer

from .distillation import distill_main, evaluate_checkpoint
from .finetune import main
from .run_eval import generate_summaries, run_generate
from .utils import SummarizationDataset, lmap, pickle_load


logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
FP16_EVER = False  # set True to also exercise the fp16 code paths used in these tests
CHEAP_ARGS = {
"logger": "default",
"num_workers": 2,
"alpha_hid": 0,
"freeze_embeds": True,
"enc_only": False,
"tgt_suffix": "",
"resume_from_checkpoint": None,
"sortish_sampler": True,
"student_decoder_layers": 1,
"val_check_interval": 1.0,
"output_dir": "",
"fp16": False,
"no_teacher": False,
"fp16_opt_level": "O1",
"gpus": 1 if torch.cuda.is_available() else 0,
"n_tpu_cores": 0,
"max_grad_norm": 1.0,
"do_train": True,
"do_predict": True,
"gradient_accumulation_steps": 1,
"server_ip": "",
"server_port": "",
"seed": 42,
"model_type": "bart",
"model_name_or_path": "sshleifer/bart-tiny-random",
"config_name": "",
"tokenizer_name": "facebook/bart-large",
"cache_dir": "",
"do_lower_case": False,
"learning_rate": 3e-05,
"weight_decay": 0.0,
"adam_epsilon": 1e-08,
"warmup_steps": 0,
"num_train_epochs": 1,
"train_batch_size": 2,
"eval_batch_size": 2,
"max_source_length": 12,
"max_target_length": 12,
"val_max_target_length": 12,
"test_max_target_length": 12,
"fast_dev_run": False,
"no_cache": False,
"n_train": -1,
"n_val": -1,
"n_test": -1,
"student_encoder_layers": 1,
"alpha_loss_encoder": 0.0,
"freeze_encoder": False,
"auto_scale_batch_size": False,
}
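
# CHEAP_ARGS mirrors the argparse flags of finetune.py / distillation.py, with values
# chosen so each test finishes in seconds (tiny random model, 12-token sequences, one
# epoch). Tests copy the dict, override a few keys, and wrap the result in a Namespace.
# A minimal sketch (the "do_train" override here is just illustrative):
#
#     args = argparse.Namespace(**{**CHEAP_ARGS, "do_train": False})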


def _dump_articles(path: Path, articles: list):
    with path.open("w") as f:
        f.write("\n".join(articles))


MSG = "T5 is broken at the moment"  # currently unused; a ready-made unittest.skip reason
T5_TINY = "patrickvonplaten/t5-tiny-random"


def make_test_data_dir():
    """Write tiny train/val/test source and target files to the system temp dir and return it."""
    tmp_dir = Path(tempfile.gettempdir())
    articles = [" Sam ate lunch today", "Sams lunch ingredients"]
    summaries = ["A very interesting story about what I ate for lunch.", "Avocado, celery, turkey, coffee"]
    for split in ["train", "val", "test"]:
        _dump_articles((tmp_dir / f"{split}.source"), articles)
        _dump_articles((tmp_dir / f"{split}.target"), summaries)
    return tmp_dir
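
# The directory returned above contains the six files the training scripts' --data_dir
# flag expects ({train,val,test}.{source,target}), so a test can do, for example:
#
#     data_dir = make_test_data_dir()
#     assert (data_dir / "train.source").exists()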


class TestSummarizationDistiller(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        logging.disable(logging.CRITICAL)  # remove noisy download output from tracebacks

    @unittest.skipUnless(torch.cuda.device_count() > 1, "skipping multiGPU test")
def test_bdc_multigpu(self):
updates = dict(
student_encoder_layers=2,
student_decoder_layers=1,
no_teacher=True,
freeze_encoder=True,
gpus=2,
sortish_sampler=False,
fp16_opt_level="O1",
fp16=FP16_EVER,
)
        self._bart_distiller_cli(updates)

    def test_bdc_t5_train(self):
updates = dict(
fp16=FP16_EVER,
gpus=1 if torch.cuda.is_available() else 0,
model_type="t5",
model_name_or_path=T5_TINY,
do_train=True,
do_predict=True,
tokenizer_name=T5_TINY,
no_teacher=True,
alpha_hid=2.0,
)
        self._bart_distiller_cli(updates)

    def test_bdc_no_teacher(self):
        updates = dict(student_encoder_layers=2, student_decoder_layers=1, no_teacher=True)
        self._bart_distiller_cli(updates)

    def test_bdc_yes_teacher(self):
        updates = dict(student_encoder_layers=2, student_decoder_layers=1)
        self._bart_distiller_cli(updates)

    def test_bdc_checkpointing(self):
updates = dict(
student_encoder_layers=2,
student_decoder_layers=1,
num_train_epochs=4,
val_check_interval=0.25,
alpha_hid=2.0,
)
model = self._bart_distiller_cli(updates, check_contents=False)
        ckpts = list(Path(model.output_dir).glob("*.ckpt"))
        self.assertEqual(1, len(ckpts))
        transformer_ckpts = list(Path(model.output_dir).glob("**/*.bin"))
        self.assertEqual(len(transformer_ckpts), len(ckpts))
        self.assertEqual(len(transformer_ckpts), 1)
        examples = lmap(str.strip, model.hparams.data_dir.joinpath("test.source").open().readlines())
        out_path = tempfile.mktemp()  # just a fresh path; generate_summaries creates the file
        generate_summaries(examples, out_path, transformer_ckpts[0].parent)
        self.assertTrue(Path(out_path).exists())
        evaluate_checkpoint(ckpts[0], dest_dir=Path(tempfile.mkdtemp()))

    def _bart_distiller_cli(self, updates, check_contents=True):
default_updates = dict(
train_batch_size=1,
eval_batch_size=2,
num_train_epochs=2,
alpha_mlm=0.2,
alpha_ce=0.8,
do_predict=True,
gpus=1 if torch.cuda.is_available() else 0,
model_name_or_path="sshleifer/tinier_bart",
teacher=CHEAP_ARGS["model_name_or_path"],
val_check_interval=0.5,
alpha_encoder_loss=0.4,
)
default_updates.update(updates)
args_d: dict = CHEAP_ARGS.copy()
tmp_dir = make_test_data_dir()
output_dir = tempfile.mkdtemp(prefix="output_")
args_d.update(data_dir=tmp_dir, output_dir=output_dir, **default_updates)
model = distill_main(argparse.Namespace(**args_d))
if not check_contents:
return model
contents = os.listdir(output_dir)
        ckpt_name = "val_avg_rouge2=0.0000-step_count=2.ckpt"
contents = {os.path.basename(p) for p in contents}
self.assertIn(ckpt_name, contents)
self.assertIn("metrics.pkl", contents)
self.assertIn("test_generations.txt", contents)
self.assertIn("val_generations_00001.txt", contents)
self.assertIn("val_results_00001.txt", contents)
self.assertIn("test_results.txt", contents)
metrics = pickle_load(Path(output_dir) / "metrics.pkl")
        # e.g. 2 epochs * (1 / 0.5 val_check_interval) + 1 = 5 expected val logs
        # (the +1 presumably covers the initial validation pass)
        desired_n_evals = int(args_d["num_train_epochs"] * (1 / args_d["val_check_interval"]) + 1)
self.assertEqual(len(metrics["val"]), desired_n_evals)
self.assertEqual(len(metrics["train"]), 0) # doesn't get logged here
return model
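
# To run just the distillation tests above (assuming pytest is installed):
#
#     pytest examples/summarization/test_summarization_examples.py -k TestSummarizationDistiller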


class TestBartExamples(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)
        logging.disable(logging.CRITICAL)  # remove noisy download output from tracebacks

    def test_bart_cnn_cli(self):
tmp = Path(tempfile.gettempdir()) / "utest_generations_bart_sum.hypo"
output_file_name = Path(tempfile.gettempdir()) / "utest_output_bart_sum.hypo"
articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]
_dump_articles(tmp, articles)
testargs = ["run_eval.py", str(tmp), str(output_file_name), "sshleifer/bart-tiny-random"]
with patch.object(sys, "argv", testargs):
run_generate()
self.assertTrue(Path(output_file_name).exists())
        os.remove(Path(output_file_name))

    def test_t5_run_sum_cli(self):
args_d: dict = CHEAP_ARGS.copy()
tmp_dir = make_test_data_dir()
output_dir = tempfile.mkdtemp(prefix="output_")
args_d.update(
data_dir=tmp_dir,
model_name_or_path=T5_TINY,
            tokenizer_name=None,  # None -> fall back to model_name_or_path's tokenizer
train_batch_size=2,
eval_batch_size=2,
gpus=0,
output_dir=output_dir,
do_predict=True,
)
assert "n_train" in args_d
args = argparse.Namespace(**args_d)
        main(args)

    def test_bart_summarization_dataset(self):
tmp_dir = Path(tempfile.gettempdir())
articles = [" Sam ate lunch today", "Sams lunch ingredients"]
summaries = ["A very interesting story about what I ate for lunch.", "Avocado, celery, turkey, coffee"]
_dump_articles((tmp_dir / "train.source"), articles)
_dump_articles((tmp_dir / "train.target"), summaries)
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
max_len_source = max(len(tokenizer.encode(a)) for a in articles)
max_len_target = max(len(tokenizer.encode(a)) for a in summaries)
trunc_target = 4
train_dataset = SummarizationDataset(
tokenizer, data_dir=tmp_dir, type_path="train", max_source_length=20, max_target_length=trunc_target,
)
dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
for batch in dataloader:
self.assertEqual(batch["attention_mask"].shape, batch["input_ids"].shape)
# show that articles were trimmed.
self.assertEqual(batch["input_ids"].shape[1], max_len_source)
self.assertGreater(20, batch["input_ids"].shape[1]) # trimmed significantly
# show that targets were truncated
self.assertEqual(batch["decoder_input_ids"].shape[1], trunc_target) # Truncated
self.assertGreater(max_len_target, trunc_target) # Truncated


def list_to_text_file(lst, path):
    """Write the given strings to `path`; items should already contain their newlines."""
    with Path(path).open("w") as f:
        f.writelines(lst)
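
# Convenience entry point (an optional addition; the suite is normally collected by
# pytest or `python -m unittest`):
if __name__ == "__main__":
    unittest.main()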