Add t5 summarization example (#3411)

* rebase to master

* change tf to pytorch

* change to pytorch

* small fix

* renaming

* add gpu training possibility

* renaming

* improve README

* incoorporate collins feedback

* better Readme

* better README.md
This commit is contained in:
Patrick von Platen 2020-03-26 18:17:55 +01:00 committed by GitHub
parent 1a6c546c6f
commit e703e923ca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 182 additions and 0 deletions

View File

@ -3,3 +3,5 @@ tensorboard
scikit-learn
seqeval
psutil
rouge-score
tensorflow_datasets

View File

@ -0,0 +1,25 @@
***This script evaluates the the multitask pre-trained checkpoint for ``t5-large`` (see paper [here](https://arxiv.org/pdf/1910.10683.pdf)) on the CNN/Daily Mail test dataset. Please note that the results in the paper were attained using a model fine-tuned on summarization, so that results will be worse here by approx. 0.5 ROUGE points***
### Get the CNN Data
First, you need to download the CNN data. It's about ~400 MB and can be downloaded by
running
```bash
python download_cnn_daily_mail.py cnn_articles_input_data.txt cnn_articles_reference_summaries.txt
```
You should confirm that each file has 11490 lines:
```bash
wc -l cnn_articles_input_data.txt # should print 11490
wc -l cnn_articles_reference_summaries.txt # should print 11490
```
### Usage
To create summaries for each article in dataset, run:
```bash
python evaluate_cnn.py cnn_articles_input_data.txt cnn_generated_articles_summaries.txt cnn_articles_reference_summaries.txt rouge_score.txt
```
The default batch size, 8, fits in 16GB GPU memory, but may need to be adjusted to fit your system.
The rouge scores "rouge1, rouge2, rougeL" are automatically created and saved in ``rouge_score.txt``.

View File

View File

@ -0,0 +1,31 @@
import argparse
from pathlib import Path
import tensorflow_datasets as tfds
def main(input_path, reference_path, data_dir):
cnn_ds = tfds.load("cnn_dailymail", split="test", shuffle_files=False, data_dir=data_dir)
cnn_ds_iter = tfds.as_numpy(cnn_ds)
test_articles_file = Path(input_path).open("w")
test_summaries_file = Path(reference_path).open("w")
for example in cnn_ds_iter:
test_articles_file.write(example["article"].decode("utf-8") + "\n")
test_articles_file.flush()
test_summaries_file.write(example["highlights"].decode("utf-8").replace("\n", " ") + "\n")
test_summaries_file.flush()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("input_path", type=str, help="where to save the articles input data")
parser.add_argument(
"reference_path", type=str, help="where to save the reference summaries",
)
parser.add_argument(
"--data_dir", type=str, default="~/tensorflow_datasets", help="where to save the tensorflow datasets.",
)
args = parser.parse_args()
main(args.input_path, args.reference_path, args.data_dir)

View File

@ -0,0 +1,95 @@
import argparse
from pathlib import Path
import torch
from tqdm import tqdm
from rouge_score import rouge_scorer, scoring
from transformers import T5ForConditionalGeneration, T5Tokenizer
def chunks(lst, n):
"""Yield successive n-sized chunks from lst."""
for i in range(0, len(lst), n):
yield lst[i : i + n]
def generate_summaries(lns, output_file_path, batch_size, device):
output_file = Path(output_file_path).open("w")
model = T5ForConditionalGeneration.from_pretrained("t5-large")
model.to(device)
tokenizer = T5Tokenizer.from_pretrained("t5-large")
# update config with summarization specific params
task_specific_params = model.config.task_specific_params
if task_specific_params is not None:
model.config.update(task_specific_params.get("summarization", {}))
for batch in tqdm(list(chunks(lns, batch_size))):
batch = [model.config.prefix + text for text in batch]
dct = tokenizer.batch_encode_plus(batch, max_length=512, return_tensors="pt", pad_to_max_length=True)
input_ids = dct["input_ids"].to(device)
attention_mask = dct["attention_mask"].to(device)
summaries = model.generate(input_ids=input_ids, attention_mask=attention_mask)
dec = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries]
for hypothesis in dec:
output_file.write(hypothesis + "\n")
output_file.flush()
def calculate_rouge(output_lns, reference_lns, score_path):
score_file = Path(score_path).open("w")
scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
aggregator = scoring.BootstrapAggregator()
for reference_ln, output_ln in zip(reference_lns, output_lns):
scores = scorer.score(reference_ln, output_ln)
aggregator.add_scores(scores)
result = aggregator.aggregate()
score_file.write(
"ROUGE_1: \n{} \n\n ROUGE_2: \n{} \n\n ROUGE_L: \n{} \n\n".format(
result["rouge1"], result["rouge2"], result["rougeL"]
)
)
def run_generate():
parser = argparse.ArgumentParser()
parser.add_argument(
"input_path", type=str, help="like cnn_dm/test_articles_input.txt",
)
parser.add_argument(
"output_path", type=str, help="where to save summaries",
)
parser.add_argument("reference_path", type=str, help="like cnn_dm/test_reference_summaries.txt")
parser.add_argument(
"score_path", type=str, help="where to save the rouge score",
)
parser.add_argument(
"--batch_size", type=int, default=8, required=False, help="batch size: how many to summarize at a time",
)
parser.add_argument(
"--no_cuda", default=False, type=bool, help="Whether to force the execution on CPU.",
)
args = parser.parse_args()
args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
source_lns = [x.rstrip() for x in open(args.input_path).readlines()]
generate_summaries(source_lns, args.output_path, args.batch_size, args.device)
output_lns = [x.rstrip() for x in open(args.output_path).readlines()]
reference_lns = [x.rstrip() for x in open(args.reference_path).readlines()]
calculate_rouge(output_lns, reference_lns, args.score_path)
if __name__ == "__main__":
run_generate()

View File

@ -0,0 +1,29 @@
import logging
import sys
import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch
from .evaluate_cnn import run_generate
articles = ["New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
class TestT5Examples(unittest.TestCase):
def test_t5_cli(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
tmp = Path(tempfile.gettempdir()) / "utest_generations.hypo"
with tmp.open("w") as f:
f.write("\n".join(articles))
testargs = ["evaluate_cnn.py", str(tmp), "output.txt", str(tmp), "score.txt"]
with patch.object(sys, "argv", testargs):
run_generate()
self.assertTrue(Path("output.txt").exists())
self.assertTrue(Path("score.txt").exists())