Summarization Examples: add Bart CNN Evaluation (#3082)

* Rename and improve example

* Add test

* slightly faster test

* style

* This breaks remy prolly

* shorter test string

* no slow

* newdir structure

* New tree

* Style

* shorter

* docs

* clean

* Attempt future import

* more import hax
This commit is contained in:
Sam Shleifer 2020-03-03 15:29:59 -05:00 committed by GitHub
parent 5c5af879b6
commit 5b396457e5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 148 additions and 14 deletions

View File

View File

@ -0,0 +1,45 @@
### Get the CNN/Daily Mail Data
To be able to reproduce the authors' results on the CNN/Daily Mail dataset you first need to download both CNN and Daily Mail datasets [from Kyunghyun Cho's website](https://cs.nyu.edu/~kcho/DMQA/) (the links next to "Stories") in the same folder. Then uncompress the archives by running:
```bash
tar -xvf cnn_stories.tgz && tar -xvf dailymail_stories.tgz
```
this should make a directory called cnn_dm/ with files like `test.source`.
To use your own data, copy that files format. Each article to be summarized is on its own line.
### Usage
To create summaries for each article in dataset, run:
```bash
python evaluate_cnn.py <path_to_test.source> cnn_test_summaries.txt
```
the default batch size, 8, fits in 16GB GPU memory, but may need to be adjusted to fit your system.
### Where is the code?
The core model is in `src/transformers/modeling_bart.py`. This directory only contains examples.
### (WIP) Rouge Scores
### Stanford CoreNLP Setup
```
ptb_tokenize () {
cat $1 | java edu.stanford.nlp.process.PTBTokenizer -ioFileList -preserveLines > $2
}
sudo apt install openjdk-8-jre-headless
sudo apt-get install ant
wget http://nlp.stanford.edu/software/stanford-corenlp-full-2018-10-05.zip
unzip stanford-corenlp-full-2018-10-05.zip
cd stanford-corenlp-full-2018-10-05
export CLASSPATH=stanford-corenlp-3.9.2.jar:stanford-corenlp-3.9.2-models.jar
```
### Rouge Setup
Install `files2rouge` following the instructions at [here](https://github.com/pltrdy/files2rouge).
I also needed to run `sudo apt-get install libxml-parser-perl`
```python
from files2rouge import files2rouge
from files2rouge import settings
files2rouge.run(<path_to_tokenized_hypo>,
<path_to_tokenized_target>,
saveto='rouge_output.txt')
```

View File

View File

@ -0,0 +1,60 @@
import argparse
from pathlib import Path
import torch
from tqdm import tqdm
from transformers import BartForMaskedLM, BartTokenizer
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
def chunks(lst, n):
"""Yield successive n-sized chunks from lst."""
for i in range(0, len(lst), n):
yield lst[i : i + n]
def generate_summaries(lns, out_file, batch_size=8, device=DEFAULT_DEVICE):
fout = Path(out_file).open("w")
model = BartForMaskedLM.from_pretrained("bart-large-cnn", output_past=True,)
tokenizer = BartTokenizer.from_pretrained("bart-large")
for batch in tqdm(list(chunks(lns, batch_size))):
dct = tokenizer.batch_encode_plus(batch, max_length=1024, return_tensors="pt", pad_to_max_length=True)
summaries = model.generate(
input_ids=dct["input_ids"].to(device),
attention_mask=dct["attention_mask"].to(device),
num_beams=4,
length_penalty=2.0,
max_length=140,
min_len=55,
no_repeat_ngram_size=3,
)
dec = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries]
for hypothesis in dec:
fout.write(hypothesis + "\n")
fout.flush()
def _run_generate():
parser = argparse.ArgumentParser()
parser.add_argument(
"source_path", type=str, help="like cnn_dm/test.source",
)
parser.add_argument(
"output_path", type=str, help="where to save summaries",
)
parser.add_argument(
"--device", type=str, required=False, default=DEFAULT_DEVICE, help="cuda, cuda:1, cpu etc.",
)
parser.add_argument(
"--bs", type=int, default=8, required=False, help="batch size: how many to summarize at a time",
)
args = parser.parse_args()
lns = [" " + x.rstrip() for x in open(args.source_path).readlines()]
generate_summaries(lns, args.output_path, batch_size=args.bs, device=args.device)
if __name__ == "__main__":
_run_generate()

View File

@ -0,0 +1,28 @@
import logging
import sys
import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch
from .evaluate_cnn import _run_generate
articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
class TestBartExamples(unittest.TestCase):
def test_bart_cnn_cli(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
tmp = Path(tempfile.gettempdir()) / "utest_generations.hypo"
with tmp.open("w") as f:
f.write("\n".join(articles))
testargs = ["evaluate_cnn.py", str(tmp), "output.txt"]
with patch.object(sys, "argv", testargs):
_run_generate()
self.assertTrue(Path("output.txt").exists())

View File

@ -15,7 +15,7 @@ pip install nltk py-rouge
cd examples/summarization
```
## Reproduce the authors' results on ROUGE
## Reproduce the authors' ROUGE score
To be able to reproduce the authors' results on the CNN/Daily Mail dataset you first need to download both CNN and Daily Mail datasets [from Kyunghyun Cho's website](https://cs.nyu.edu/~kcho/DMQA/) (the links next to "Stories") in the same folder. Then uncompress the archives by running:

View File

@ -11,12 +11,13 @@ from tqdm import tqdm
from modeling_bertabs import BertAbs, build_predictor
from transformers import BertTokenizer
from utils_summarization import (
SummarizationDataset,
from .utils_summarization import (
CNNDMDataset,
build_mask,
compute_token_type_ids,
encode_for_summarization,
fit_to_block_size,
truncate_or_pad,
)
@ -194,7 +195,7 @@ def build_data_iterator(args, tokenizer):
def load_and_cache_examples(args, tokenizer):
dataset = SummarizationDataset(args.documents_dir)
dataset = CNNDMDataset(args.documents_dir)
return dataset
@ -211,7 +212,7 @@ def collate(data, tokenizer, block_size, device):
encoded_text = [encode_for_summarization(story, summary, tokenizer) for _, story, summary in data]
encoded_stories = torch.tensor(
[fit_to_block_size(story, block_size, tokenizer.pad_token_id) for story, _ in encoded_text]
[truncate_or_pad(story, block_size, tokenizer.pad_token_id) for story, _ in encoded_text]
)
encoder_token_type_ids = compute_token_type_ids(encoded_stories, tokenizer.cls_token_id)
encoder_mask = build_mask(encoded_stories, tokenizer.pad_token_id)

View File

@ -17,7 +17,7 @@ import unittest
import numpy as np
import torch
from utils_summarization import build_mask, compute_token_type_ids, fit_to_block_size, process_story
from .utils_summarization import build_mask, compute_token_type_ids, process_story, truncate_or_pad
class SummarizationDataProcessingTest(unittest.TestCase):
@ -28,19 +28,19 @@ class SummarizationDataProcessingTest(unittest.TestCase):
""" Pad the sequence with 0 if the sequence is smaller than the block size."""
sequence = [1, 2, 3, 4]
expected_output = [1, 2, 3, 4, 0, 0, 0, 0, 0, 0]
self.assertEqual(fit_to_block_size(sequence, self.block_size, 0), expected_output)
self.assertEqual(truncate_or_pad(sequence, self.block_size, 0), expected_output)
def test_fit_to_block_sequence_fit_exactly(self):
""" Do nothing if the sequence is the right size. """
sequence = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
expected_output = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
self.assertEqual(fit_to_block_size(sequence, self.block_size, 0), expected_output)
self.assertEqual(truncate_or_pad(sequence, self.block_size, 0), expected_output)
def test_fit_to_block_sequence_too_big(self):
""" Truncate the sequence if it is too long. """
sequence = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
expected_output = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
self.assertEqual(fit_to_block_size(sequence, self.block_size, 0), expected_output)
self.assertEqual(truncate_or_pad(sequence, self.block_size, 0), expected_output)
def test_process_story_no_highlights(self):
""" Processing a story with no highlights returns an empty list for the summary.

View File

@ -10,7 +10,7 @@ from torch.utils.data import Dataset
# ------------
class SummarizationDataset(Dataset):
class CNNDMDataset(Dataset):
""" Abstracts the dataset used to train seq2seq models.
The class will process the documents that are located in the specified
@ -62,11 +62,11 @@ class SummarizationDataset(Dataset):
def process_story(raw_story):
""" Extract the story and summary from a story file.
Attributes:
Arguments:
raw_story (str): content of the story file as an utf-8 encoded string.
Raises:
IndexError: If the stoy is empty or contains no highlights.
IndexError: If the story is empty or contains no highlights.
"""
nonempty_lines = list(filter(lambda x: len(x) != 0, [line.strip() for line in raw_story.split("\n")]))
@ -107,7 +107,7 @@ def _add_missing_period(line):
# --------------------------
def fit_to_block_size(sequence, block_size, pad_token_id):
def truncate_or_pad(sequence, block_size, pad_token_id):
""" Adapt the source and target sequences' lengths to the block size.
If the sequence is shorter we append padding token to the right of the sequence.
"""