diff --git a/examples/requirements.txt b/examples/requirements.txt index 6a4126c9263..e1f1a2c114a 100644 --- a/examples/requirements.txt +++ b/examples/requirements.txt @@ -3,3 +3,5 @@ tensorboard scikit-learn seqeval psutil +rouge-score +tensorflow_datasets diff --git a/examples/summarization/t5/README.md b/examples/summarization/t5/README.md new file mode 100644 index 00000000000..222e5ff17cf --- /dev/null +++ b/examples/summarization/t5/README.md @@ -0,0 +1,25 @@ +***This script evaluates the the multitask pre-trained checkpoint for ``t5-large`` (see paper [here](https://arxiv.org/pdf/1910.10683.pdf)) on the CNN/Daily Mail test dataset. Please note that the results in the paper were attained using a model fine-tuned on summarization, so that results will be worse here by approx. 0.5 ROUGE points*** + +### Get the CNN Data +First, you need to download the CNN data. It's about ~400 MB and can be downloaded by +running + +```bash +python download_cnn_daily_mail.py cnn_articles_input_data.txt cnn_articles_reference_summaries.txt +``` + +You should confirm that each file has 11490 lines: + +```bash +wc -l cnn_articles_input_data.txt # should print 11490 +wc -l cnn_articles_reference_summaries.txt # should print 11490 +``` + +### Usage + +To create summaries for each article in dataset, run: +```bash +python evaluate_cnn.py cnn_articles_input_data.txt cnn_generated_articles_summaries.txt cnn_articles_reference_summaries.txt rouge_score.txt +``` +The default batch size, 8, fits in 16GB GPU memory, but may need to be adjusted to fit your system. +The rouge scores "rouge1, rouge2, rougeL" are automatically created and saved in ``rouge_score.txt``. diff --git a/examples/summarization/t5/__init__.py b/examples/summarization/t5/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/examples/summarization/t5/download_cnn_daily_mail.py b/examples/summarization/t5/download_cnn_daily_mail.py new file mode 100644 index 00000000000..5089d9c1dc4 --- /dev/null +++ b/examples/summarization/t5/download_cnn_daily_mail.py @@ -0,0 +1,31 @@ +import argparse +from pathlib import Path + +import tensorflow_datasets as tfds + + +def main(input_path, reference_path, data_dir): + cnn_ds = tfds.load("cnn_dailymail", split="test", shuffle_files=False, data_dir=data_dir) + cnn_ds_iter = tfds.as_numpy(cnn_ds) + + test_articles_file = Path(input_path).open("w") + test_summaries_file = Path(reference_path).open("w") + + for example in cnn_ds_iter: + test_articles_file.write(example["article"].decode("utf-8") + "\n") + test_articles_file.flush() + test_summaries_file.write(example["highlights"].decode("utf-8").replace("\n", " ") + "\n") + test_summaries_file.flush() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("input_path", type=str, help="where to save the articles input data") + parser.add_argument( + "reference_path", type=str, help="where to save the reference summaries", + ) + parser.add_argument( + "--data_dir", type=str, default="~/tensorflow_datasets", help="where to save the tensorflow datasets.", + ) + args = parser.parse_args() + main(args.input_path, args.reference_path, args.data_dir) diff --git a/examples/summarization/t5/evaluate_cnn.py b/examples/summarization/t5/evaluate_cnn.py new file mode 100644 index 00000000000..18750183ac8 --- /dev/null +++ b/examples/summarization/t5/evaluate_cnn.py @@ -0,0 +1,95 @@ +import argparse +from pathlib import Path + +import torch +from tqdm import tqdm + +from rouge_score import rouge_scorer, scoring +from transformers import T5ForConditionalGeneration, T5Tokenizer + + +def chunks(lst, n): + """Yield successive n-sized chunks from lst.""" + for i in range(0, len(lst), n): + yield lst[i : i + n] + + +def generate_summaries(lns, output_file_path, batch_size, device): + output_file = Path(output_file_path).open("w") + + model = T5ForConditionalGeneration.from_pretrained("t5-large") + model.to(device) + + tokenizer = T5Tokenizer.from_pretrained("t5-large") + + # update config with summarization specific params + task_specific_params = model.config.task_specific_params + if task_specific_params is not None: + model.config.update(task_specific_params.get("summarization", {})) + + for batch in tqdm(list(chunks(lns, batch_size))): + batch = [model.config.prefix + text for text in batch] + + dct = tokenizer.batch_encode_plus(batch, max_length=512, return_tensors="pt", pad_to_max_length=True) + input_ids = dct["input_ids"].to(device) + attention_mask = dct["attention_mask"].to(device) + + summaries = model.generate(input_ids=input_ids, attention_mask=attention_mask) + dec = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries] + + for hypothesis in dec: + output_file.write(hypothesis + "\n") + output_file.flush() + + +def calculate_rouge(output_lns, reference_lns, score_path): + score_file = Path(score_path).open("w") + scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True) + aggregator = scoring.BootstrapAggregator() + + for reference_ln, output_ln in zip(reference_lns, output_lns): + scores = scorer.score(reference_ln, output_ln) + aggregator.add_scores(scores) + + result = aggregator.aggregate() + score_file.write( + "ROUGE_1: \n{} \n\n ROUGE_2: \n{} \n\n ROUGE_L: \n{} \n\n".format( + result["rouge1"], result["rouge2"], result["rougeL"] + ) + ) + + +def run_generate(): + parser = argparse.ArgumentParser() + parser.add_argument( + "input_path", type=str, help="like cnn_dm/test_articles_input.txt", + ) + parser.add_argument( + "output_path", type=str, help="where to save summaries", + ) + parser.add_argument("reference_path", type=str, help="like cnn_dm/test_reference_summaries.txt") + parser.add_argument( + "score_path", type=str, help="where to save the rouge score", + ) + parser.add_argument( + "--batch_size", type=int, default=8, required=False, help="batch size: how many to summarize at a time", + ) + parser.add_argument( + "--no_cuda", default=False, type=bool, help="Whether to force the execution on CPU.", + ) + + args = parser.parse_args() + args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") + + source_lns = [x.rstrip() for x in open(args.input_path).readlines()] + + generate_summaries(source_lns, args.output_path, args.batch_size, args.device) + + output_lns = [x.rstrip() for x in open(args.output_path).readlines()] + reference_lns = [x.rstrip() for x in open(args.reference_path).readlines()] + + calculate_rouge(output_lns, reference_lns, args.score_path) + + +if __name__ == "__main__": + run_generate() diff --git a/examples/summarization/t5/test_t5_examples.py b/examples/summarization/t5/test_t5_examples.py new file mode 100644 index 00000000000..eb24c31c89c --- /dev/null +++ b/examples/summarization/t5/test_t5_examples.py @@ -0,0 +1,29 @@ +import logging +import sys +import tempfile +import unittest +from pathlib import Path +from unittest.mock import patch + +from .evaluate_cnn import run_generate + + +articles = ["New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."] + +logging.basicConfig(level=logging.DEBUG) + +logger = logging.getLogger() + + +class TestT5Examples(unittest.TestCase): + def test_t5_cli(self): + stream_handler = logging.StreamHandler(sys.stdout) + logger.addHandler(stream_handler) + tmp = Path(tempfile.gettempdir()) / "utest_generations.hypo" + with tmp.open("w") as f: + f.write("\n".join(articles)) + testargs = ["evaluate_cnn.py", str(tmp), "output.txt", str(tmp), "score.txt"] + with patch.object(sys, "argv", testargs): + run_generate() + self.assertTrue(Path("output.txt").exists()) + self.assertTrue(Path("score.txt").exists())