mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00

[examples] consolidate summarization examples (#4837)

parent 9f5d5a531d
commit 02e5f79662
examples/summarization/README.md
@@ -1,4 +1,7 @@

### Get CNN Data

Both types of models require the CNN data, but they follow different procedures to obtain it.

#### For BART models

To reproduce the authors' results on the CNN/Daily Mail dataset, you first need to download both the CNN and Daily Mail datasets [from Kyunghyun Cho's website](https://cs.nyu.edu/~kcho/DMQA/) (the links next to "Stories") into the same folder. Then uncompress the archives by running:

@@ -6,25 +9,43 @@

```bash
wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/cnn_dm.tgz
tar -xzvf cnn_dm.tgz
```
This should make a directory called `cnn_dm/` with files like `test.source`.
To use your own data, copy that file's format: each article to be summarized goes on its own line.
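For instance, a minimal Python sketch of writing data in that format (the article strings and the output filename here are just placeholders):

```python
# Write articles one per line, matching the cnn_dm/ file format.
articles = [
    "First article text, kept on a single line.",
    "Second article text, also kept on a single line.",
]

with open("test.source", "w", encoding="utf-8") as f:
    for article in articles:
        # Embedded newlines would split an article across lines, so replace them.
        f.write(article.replace("\n", " ") + "\n")
```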
#### For T5 models
First, you need to download the CNN data. It is about 400 MB and can be downloaded by running:

```bash
python download_cnn_daily_mail.py cnn_articles_input_data.txt cnn_articles_reference_summaries.txt
```
You should confirm that each file has 11490 lines:

```bash
wc -l cnn_articles_input_data.txt           # should print 11490
wc -l cnn_articles_reference_summaries.txt  # should print 11490
```
### Evaluation
To create summaries for each article in the dataset, run:

```bash
python evaluate_cnn.py <path_to_test.source> test_generations.txt <model-name>
```

The default batch size, 8, fits in 16 GB of GPU memory but may need to be adjusted to fit your system.
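If you would rather call the evaluation code from Python than from the command line, `evaluate_cnn.py` (shown in full further below) exposes a `generate_summaries` function; a minimal sketch, with illustrative paths:

```python
from evaluate_cnn import generate_summaries

# One article per line, as in cnn_dm/test.source.
examples = [x.rstrip() for x in open("cnn_dm/test.source", encoding="utf-8").readlines()]

# A smaller batch size trades speed for lower peak GPU memory.
generate_summaries(
    examples, "test_generations.txt", "facebook/bart-large-cnn", batch_size=4, device="cuda"
)
```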
### Training
Run/modify `finetune_bart.sh` or `finetune_t5.sh`

### Where is the code?

The core model is in `src/transformers/modeling_bart.py`. This directory only contains examples.
## (WIP) Rouge Scores
To create summaries for each article in the dataset and also calculate ROUGE scores, run:

```bash
python evaluate_cnn.py <path_to_test.source> test_generations.txt <model-name> --reference_path <path_to_correct_summaries> --score_path <path_to_save_rouge_scores>
```

The ROUGE scores "rouge1, rouge2, rougeL" are computed automatically and saved in ``<path_to_save_rouge_scores>``.
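As a quick illustration of what gets written there, a minimal sketch using the same `rouge_score` package that `evaluate_cnn.py` relies on (the two strings are toy examples):

```python
from rouge_score import rouge_scorer

scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
# score(target, prediction) returns precision, recall and F-measure per ROUGE type.
scores = scorer.score(
    "the cat was found under the bed",  # reference summary
    "the cat was under the bed",        # generated summary
)
print(scores["rouge1"], scores["rouge2"], scores["rougeL"])
```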
### Stanford CoreNLP Setup
```bash
ptb_tokenize () {
```
examples/summarization/bart/evaluate_cnn.py
Deleted file
@@ -1,71 +0,0 @@
import argparse
from pathlib import Path

import torch
from tqdm import tqdm

from transformers import BartForConditionalGeneration, BartTokenizer


DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def chunks(lst, n):
    """Yield successive n-sized chunks from lst."""
    for i in range(0, len(lst), n):
        yield lst[i : i + n]


def generate_summaries(
    examples: list, out_file: str, model_name: str, batch_size: int = 8, device: str = DEFAULT_DEVICE
):
    fout = Path(out_file).open("w")
    model = BartForConditionalGeneration.from_pretrained(model_name).to(device)
    tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")

    max_length = 140
    min_length = 55

    for batch in tqdm(list(chunks(examples, batch_size))):
        dct = tokenizer.batch_encode_plus(batch, max_length=1024, return_tensors="pt", pad_to_max_length=True)
        summaries = model.generate(
            input_ids=dct["input_ids"].to(device),
            attention_mask=dct["attention_mask"].to(device),
            num_beams=4,
            length_penalty=2.0,
            max_length=max_length + 2,  # +2 from original because we start at step=1 and stop before max_length
            min_length=min_length + 1,  # +1 from original because we start at step=1
            no_repeat_ngram_size=3,
            early_stopping=True,
            decoder_start_token_id=model.config.eos_token_id,
        )
        dec = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries]
        for hypothesis in dec:
            fout.write(hypothesis + "\n")
        fout.flush()


def run_generate():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "source_path", type=str, help="like cnn_dm/test.source",
    )
    parser.add_argument(
        "output_path", type=str, help="where to save summaries",
    )
    parser.add_argument(
        "model_name", type=str, default="facebook/bart-large-cnn", help="like bart-large-cnn",
    )
    parser.add_argument(
        "--device", type=str, required=False, default=DEFAULT_DEVICE, help="cuda, cuda:1, cpu etc.",
    )
    parser.add_argument(
        "--bs", type=int, default=8, required=False, help="batch size: how many to summarize at a time",
    )
    args = parser.parse_args()
    examples = [" " + x.rstrip() for x in open(args.source_path).readlines()]
    generate_summaries(examples, args.output_path, args.model_name, batch_size=args.bs, device=args.device)


if __name__ == "__main__":
    run_generate()
examples/summarization/download_cnn_daily_mail.py
@@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
import argparse
from pathlib import Path

@@ -8,8 +9,8 @@ def main(input_path, reference_path, data_dir):
    cnn_ds = tfds.load("cnn_dailymail", split="test", shuffle_files=False, data_dir=data_dir)
    cnn_ds_iter = tfds.as_numpy(cnn_ds)

    test_articles_file = Path(input_path).open("w", encoding="utf-8")
    test_summaries_file = Path(reference_path).open("w", encoding="utf-8")

    for example in cnn_ds_iter:
        test_articles_file.write(example["article"].decode("utf-8") + "\n")
examples/summarization/evaluate_cnn.py
New file (100 lines)
@@ -0,0 +1,100 @@
import argparse
from pathlib import Path

import torch
from rouge_score import rouge_scorer, scoring
from tqdm import tqdm

from transformers import AutoModelWithLMHead, AutoTokenizer


DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def chunks(lst, n):
    """Yield successive n-sized chunks from lst."""
    for i in range(0, len(lst), n):
        yield lst[i : i + n]


def generate_summaries(
    examples: list, out_file: str, model_name: str, batch_size: int = 8, device: str = DEFAULT_DEVICE
):
    fout = Path(out_file).open("w", encoding="utf-8")
    model = AutoModelWithLMHead.from_pretrained(model_name).to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # update config with summarization specific params
    task_specific_params = model.config.task_specific_params
    if task_specific_params is not None:
        model.config.update(task_specific_params.get("summarization", {}))

    for batch in tqdm(list(chunks(examples, batch_size))):
        if "t5" in model_name:
            batch = [model.config.prefix + text for text in batch]
        dct = tokenizer.batch_encode_plus(batch, max_length=1024, return_tensors="pt", pad_to_max_length=True).to(
            device
        )
        summaries = model.generate(**dct)

        dec = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        for hypothesis in dec:
            fout.write(hypothesis + "\n")
        fout.flush()


def calculate_rouge(output_lns, reference_lns, score_path):
    score_file = Path(score_path).open("w")
    scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
    aggregator = scoring.BootstrapAggregator()

    for reference_ln, output_ln in zip(reference_lns, output_lns):
        scores = scorer.score(reference_ln, output_ln)
        aggregator.add_scores(scores)

    result = aggregator.aggregate()
    score_file.write(
        "ROUGE_1: \n{} \n\n ROUGE_2: \n{} \n\n ROUGE_L: \n{} \n\n".format(
            result["rouge1"], result["rouge2"], result["rougeL"]
        )
    )


def run_generate():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "input_path", type=str, help="like cnn_dm/test.source or cnn_dm/test_articles_input.txt",
    )
    parser.add_argument(
        "output_path", type=str, help="where to save summaries",
    )
    parser.add_argument(
        "model_name",
        type=str,
        default="facebook/bart-large-cnn",
        help="like 'bart-large-cnn', 't5-small', 't5-base', 't5-large', 't5-3b', 't5-11b'",
    )
    parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test_reference_summaries.txt")
    parser.add_argument(
        "--score_path", type=str, required=False, help="where to save the rouge score",
    )
    parser.add_argument(
        "--device", type=str, required=False, default=DEFAULT_DEVICE, help="cuda, cuda:1, cpu etc.",
    )
    parser.add_argument(
        "--bs", type=int, default=8, required=False, help="batch size: how many to summarize at a time",
    )
    args = parser.parse_args()
    examples = [" " + x.rstrip() if "t5" in args.model_name else x.rstrip() for x in open(args.input_path).readlines()]

    generate_summaries(examples, args.output_path, args.model_name, batch_size=args.bs, device=args.device)
    if args.score_path is not None:
        output_lns = [x.rstrip() for x in open(args.output_path).readlines()]
        reference_lns = [x.rstrip() for x in open(args.reference_path).readlines()]

        calculate_rouge(output_lns, reference_lns, args.score_path)


if __name__ == "__main__":
    run_generate()
examples/summarization/bart/run_train.sh → examples/summarization/finetune_bart.sh (executable file → normal file)
examples/summarization/bart/run_train_tiny.sh → examples/summarization/finetune_bart_tiny.sh (executable file → normal file)

examples/summarization/finetune_t5.sh
New file (19 lines)
@@ -0,0 +1,19 @@
export OUTPUT_DIR_NAME=t5
export CURRENT_DIR=${PWD}
export OUTPUT_DIR=${CURRENT_DIR}/${OUTPUT_DIR_NAME}

# Make output directory if it doesn't exist
mkdir -p $OUTPUT_DIR

# Add parent directory to python path to access lightning_base.py
export PYTHONPATH="../../":"${PYTHONPATH}"

python finetune.py \
--data_dir=./cnn-dailymail/cnn_dm \
--model_name_or_path=t5-large \
--model_type=t5 \
--learning_rate=3e-5 \
--train_batch_size=4 \
--eval_batch_size=4 \
--output_dir=$OUTPUT_DIR \
--do_train $@
examples/summarization/t5/README.md
Deleted file
@@ -1,29 +0,0 @@

***This script evaluates the multitask pre-trained checkpoint for ``t5-base`` (see the paper [here](https://arxiv.org/pdf/1910.10683.pdf)) on the CNN/Daily Mail test dataset. Please note that the results in the paper were attained using a model fine-tuned on summarization, so results here will be worse by approx. 0.5 ROUGE points.***

### Get the CNN Data
First, you need to download the CNN data. It is about 400 MB and can be downloaded by running:

```bash
python download_cnn_daily_mail.py cnn_articles_input_data.txt cnn_articles_reference_summaries.txt
```

You should confirm that each file has 11490 lines:

```bash
wc -l cnn_articles_input_data.txt           # should print 11490
wc -l cnn_articles_reference_summaries.txt  # should print 11490
```

### Generating Summaries

To create summaries for each article in the dataset, run:
```bash
python evaluate_cnn.py cnn_articles_input_data.txt cnn_generated_articles_summaries.txt cnn_articles_reference_summaries.txt rouge_score.txt
```
The default batch size, 8, fits in 16 GB of GPU memory but may need to be adjusted to fit your system.
The ROUGE scores "rouge1, rouge2, rougeL" are automatically created and saved in ``rouge_score.txt``.

### Finetuning
Pass `--model_type=t5` and a T5 model name to `examples/summarization/bart/finetune.py`.
examples/summarization/t5/evaluate_cnn.py
Deleted file
@@ -1,101 +0,0 @@
import argparse
from pathlib import Path

import torch
from rouge_score import rouge_scorer, scoring
from tqdm import tqdm

from transformers import T5ForConditionalGeneration, T5Tokenizer


def chunks(lst, n):
    """Yield successive n-sized chunks from lst."""
    for i in range(0, len(lst), n):
        yield lst[i : i + n]


def generate_summaries(lns, output_file_path, model_size, batch_size, device):
    output_file = Path(output_file_path).open("w")

    model = T5ForConditionalGeneration.from_pretrained(model_size)
    model.to(device)

    tokenizer = T5Tokenizer.from_pretrained(model_size)

    # update config with summarization specific params
    task_specific_params = model.config.task_specific_params
    if task_specific_params is not None:
        model.config.update(task_specific_params.get("summarization", {}))

    for batch in tqdm(list(chunks(lns, batch_size))):
        batch = [model.config.prefix + text for text in batch]

        dct = tokenizer.batch_encode_plus(batch, max_length=512, return_tensors="pt", pad_to_max_length=True)
        input_ids = dct["input_ids"].to(device)
        attention_mask = dct["attention_mask"].to(device)

        summaries = model.generate(input_ids=input_ids, attention_mask=attention_mask)
        dec = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries]

        for hypothesis in dec:
            output_file.write(hypothesis + "\n")
        output_file.flush()


def calculate_rouge(output_lns, reference_lns, score_path):
    score_file = Path(score_path).open("w")
    scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
    aggregator = scoring.BootstrapAggregator()

    for reference_ln, output_ln in zip(reference_lns, output_lns):
        scores = scorer.score(reference_ln, output_ln)
        aggregator.add_scores(scores)

    result = aggregator.aggregate()
    score_file.write(
        "ROUGE_1: \n{} \n\n ROUGE_2: \n{} \n\n ROUGE_L: \n{} \n\n".format(
            result["rouge1"], result["rouge2"], result["rougeL"]
        )
    )


def run_generate():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "model_size",
        type=str,
        help="T5 model size, either 't5-small', 't5-base', 't5-large', 't5-3b', 't5-11b'. Defaults to 't5-base'.",
        default="t5-base",
    )
    parser.add_argument(
        "input_path", type=str, help="like cnn_dm/test_articles_input.txt",
    )
    parser.add_argument(
        "output_path", type=str, help="where to save summaries",
    )
    parser.add_argument("reference_path", type=str, help="like cnn_dm/test_reference_summaries.txt")
    parser.add_argument(
        "score_path", type=str, help="where to save the rouge score",
    )
    parser.add_argument(
        "--batch_size", type=int, default=8, required=False, help="batch size: how many to summarize at a time",
    )
    parser.add_argument(
        "--no_cuda", default=False, type=bool, help="Whether to force the execution on CPU.",
    )

    args = parser.parse_args()
    args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")

    source_lns = [x.rstrip() for x in open(args.input_path).readlines()]

    generate_summaries(source_lns, args.output_path, args.model_size, args.batch_size, args.device)

    output_lns = [x.rstrip() for x in open(args.output_path).readlines()]
    reference_lns = [x.rstrip() for x in open(args.reference_path).readlines()]

    calculate_rouge(output_lns, reference_lns, args.score_path)


if __name__ == "__main__":
    run_generate()
Deleted file
@@ -1,44 +0,0 @@
import logging
import sys
import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch

from .evaluate_cnn import run_generate


output_file_name = "output_t5_sum.txt"
score_file_name = "score_t5_sum.txt"

articles = ["New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]

logging.basicConfig(level=logging.DEBUG)

logger = logging.getLogger()


class TestT5Examples(unittest.TestCase):
    def test_t5_cli(self):
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)
        tmp = Path(tempfile.gettempdir()) / "utest_generations_t5_sum.hypo"
        with tmp.open("w") as f:
            f.write("\n".join(articles))

        output_file_name = Path(tempfile.gettempdir()) / "utest_output_t5_sum.hypo"
        score_file_name = Path(tempfile.gettempdir()) / "utest_score_t5_sum.hypo"

        testargs = [
            "evaluate_cnn.py",
            "patrickvonplaten/t5-tiny-random",
            str(tmp),
            str(output_file_name),
            str(tmp),
            str(score_file_name),
        ]

        with patch.object(sys, "argv", testargs):
            run_generate()
            self.assertTrue(Path(output_file_name).exists())
            self.assertTrue(Path(score_file_name).exists())
@@ -146,3 +146,34 @@ class TestBartExamples(unittest.TestCase):
        # show that targets were truncated
        self.assertEqual(batch["target_ids"].shape[1], trunc_target)  # Truncated
        self.assertGreater(max_len_target, trunc_target)  # Truncated


class TestT5Examples(unittest.TestCase):
    def test_t5_cli(self):
        output_file_name = "output_t5_sum.txt"
        score_file_name = "score_t5_sum.txt"
        articles = ["New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)
        tmp = Path(tempfile.gettempdir()) / "utest_generations_t5_sum.hypo"
        with tmp.open("w", encoding="utf-8") as f:
            f.write("\n".join(articles))

        output_file_name = Path(tempfile.gettempdir()) / "utest_output_t5_sum.hypo"
        score_file_name = Path(tempfile.gettempdir()) / "utest_score_t5_sum.hypo"

        testargs = [
            "evaluate_cnn.py",
            str(tmp),
            str(output_file_name),
            "patrickvonplaten/t5-tiny-random",
            "--reference_path",
            str(tmp),
            "--score_path",
            str(score_file_name),
        ]

        with patch.object(sys, "argv", testargs):
            run_generate()
            self.assertTrue(Path(output_file_name).exists())
            self.assertTrue(Path(score_file_name).exists())