[examples] summarization/bart/finetune.py supports t5 (#3824)

renames `run_bart_sum.py` to `finetune.py`
This commit is contained in:
Sam Shleifer 2020-04-16 15:15:19 -04:00 committed by GitHub
parent 0cec4fab7d
commit f0c96fafd1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 36 additions and 14 deletions

View File

@ -19,7 +19,7 @@ except ImportError:
logger = logging.getLogger(__name__)
class BartSystem(BaseTransformer):
class SummarizationTrainer(BaseTransformer):
mode = "language-modeling"
@ -64,18 +64,18 @@ class BartSystem(BaseTransformer):
return {"avg_val_loss": avg_loss, "log": tensorboard_logs}
def test_step(self, batch, batch_idx):
# NOTE: this generation will not use the cache.
pad_token_id = self.tokenizer.pad_token_id
source_ids, source_mask, y = SummarizationDataset.trim_seq2seq_batch(batch, pad_token_id)
# NOTE: these kwargs get more speed and lower quality summaries than those in evaluate_cnn.py.
# NOTE: the following kwargs get more speed and lower quality summaries than those in evaluate_cnn.py
generated_ids = self.model.generate(
source_ids,
source_mask,
input_ids=source_ids,
attention_mask=source_mask,
num_beams=1,
max_length=80,
repetition_penalty=2.5,
length_penalty=1.0,
early_stopping=True,
use_cache=True,
)
preds = [
self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True)
@ -161,20 +161,20 @@ def main(args):
if not args.output_dir:
args.output_dir = os.path.join("./results", f"{args.task}_{args.model_type}_{time.strftime('%Y%m%d_%H%M%S')}",)
os.makedirs(args.output_dir)
model = BartSystem(args)
model = SummarizationTrainer(args)
trainer = generic_train(model, args)
# Optionally, predict on dev set and write to output_dir
if args.do_predict:
checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True)))
BartSystem.load_from_checkpoint(checkpoints[-1])
SummarizationTrainer.load_from_checkpoint(checkpoints[-1])
trainer.test(model)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
add_generic_args(parser, os.getcwd())
parser = BartSystem.add_model_specific_args(parser, os.getcwd())
parser = SummarizationTrainer.add_model_specific_args(parser, os.getcwd())
args = parser.parse_args()
main(args)

View File

@ -8,7 +8,7 @@ mkdir -p $OUTPUT_DIR
# Add parent directory to python path to access transformer_base.py
export PYTHONPATH="../../":"${PYTHONPATH}"
python run_bart_sum.py \
python finetune.py \
--data_dir=./cnn-dailymail/cnn_dm \
--model_type=bart \
--model_name_or_path=bart-large \

View File

@ -14,7 +14,7 @@ mkdir -p $OUTPUT_DIR
# Add parent directory to python path to access transformer_base.py and utils.py
export PYTHONPATH="../../":"${PYTHONPATH}"
python run_bart_sum.py \
python finetune.py \
--data_dir=cnn_tiny/ \
--model_type=bart \
--model_name_or_path=sshleifer/bart-tiny-random \

View File

@ -12,7 +12,7 @@ from torch.utils.data import DataLoader
from transformers import BartTokenizer
from .evaluate_cnn import run_generate
from .run_bart_sum import main
from .finetune import main
from .utils import SummarizationDataset
@ -92,9 +92,27 @@ class TestBartExamples(unittest.TestCase):
args_d.update(
data_dir=tmp_dir, model_type="bart", train_batch_size=2, eval_batch_size=2, n_gpu=0, output_dir=output_dir,
)
main(argparse.Namespace(**args_d))
args_d.update({"do_train": False, "do_predict": True})
main(argparse.Namespace(**args_d))
args = argparse.Namespace(**args_d)
main(args)
def test_t5_run_sum_cli(self):
args_d: dict = DEFAULT_ARGS.copy()
tmp_dir = make_test_data_dir()
output_dir = tempfile.mkdtemp(prefix="output_")
args_d.update(
data_dir=tmp_dir,
model_type="t5",
model_name_or_path="patrickvonplaten/t5-tiny-random",
train_batch_size=2,
eval_batch_size=2,
n_gpu=0,
output_dir=output_dir,
do_predict=True,
)
main(argparse.Namespace(**args_d))
# args_d.update({"do_train": False, "do_predict": True})
# main(argparse.Namespace(**args_d))
def test_bart_summarization_dataset(self):
tmp_dir = Path(tempfile.gettempdir())

View File

@ -15,7 +15,7 @@ wc -l cnn_articles_input_data.txt # should print 11490
wc -l cnn_articles_reference_summaries.txt # should print 11490
```
### Usage
### Generating Summaries
To create summaries for each article in dataset, run:
```bash
@ -23,3 +23,7 @@ python evaluate_cnn.py cnn_articles_input_data.txt cnn_generated_articles_summar
```
The default batch size, 8, fits in 16GB GPU memory, but may need to be adjusted to fit your system.
The rouge scores "rouge1, rouge2, rougeL" are automatically created and saved in ``rouge_score.txt``.
### Finetuning
Pass model_type=t5 and model `examples/summarization/bart/finetune.py`