# transformers/examples/summarization/bart/run_bart_sum.py


import argparse
import glob
import logging
import os
import time

import torch
from torch.utils.data import DataLoader

from transformer_base import BaseTransformer, add_generic_args, generic_train, get_linear_schedule_with_warmup
from utils import SummarizationDataset

logger = logging.getLogger(__name__)

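# BartSystem fine-tunes BART for abstractive summarization as a PyTorch
# Lightning module. BaseTransformer (from transformer_base) is assumed to
# provide self.model, self.tokenizer, and the optimizer self.opt, all of
# which are used below but never defined in this file.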
class BartSystem(BaseTransformer):

    mode = "language-modeling"

    def __init__(self, hparams):
        super().__init__(hparams, num_labels=None, mode=self.mode)

    def forward(
        self, input_ids, attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, lm_labels=None
    ):
        return self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            lm_labels=lm_labels,
        )

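    # _step builds decoder inputs and labels by shifting the target sequence:
    # the decoder sees tokens [0, n-1] and is trained to predict tokens [1, n]
    # (teacher forcing). Padding positions are set to -100, the default
    # ignore_index of torch.nn.CrossEntropyLoss, so they contribute no loss.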
    def _step(self, batch):
        y = batch["target_ids"]
        y_ids = y[:, :-1].contiguous()
        lm_labels = y[:, 1:].clone()
        lm_labels[y[:, 1:] == self.tokenizer.pad_token_id] = -100
        outputs = self(
            input_ids=batch["source_ids"],
            attention_mask=batch["source_mask"],
            decoder_input_ids=y_ids,
            lm_labels=lm_labels,
        )
        loss = outputs[0]
        return loss

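    # training_step and validation_step both delegate to _step; the "log" key
    # in the returned dict is picked up by Lightning's logger (TensorBoard by
    # default in the Lightning versions this script targets).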
    def training_step(self, batch, batch_idx):
        loss = self._step(batch)
        tensorboard_logs = {"train_loss": loss}
        return {"loss": loss, "log": tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        loss = self._step(batch)
        return {"val_loss": loss}

    def validation_end(self, outputs):
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        tensorboard_logs = {"val_loss": avg_loss}
        return {"avg_val_loss": avg_loss, "log": tensorboard_logs}

    def test_step(self, batch, batch_idx):
        generated_ids = self.model.generate(
            batch["source_ids"],
            attention_mask=batch["source_mask"],
            num_beams=1,
            max_length=80,
            repetition_penalty=2.5,
            length_penalty=1.0,
            early_stopping=True,
        )
        preds = [
            self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True)
            for g in generated_ids
        ]
        target = [
            self.tokenizer.decode(t, skip_special_tokens=True, clean_up_tokenization_spaces=True)
            for t in batch["target_ids"]
        ]
        loss = self._step(batch)
        return {"val_loss": loss, "preds": preds, "target": target}

    def test_end(self, outputs):
        return self.validation_end(outputs)

    def test_epoch_end(self, outputs):
        output_test_predictions_file = os.path.join(self.hparams.output_dir, "test_predictions.txt")
        output_test_targets_file = os.path.join(self.hparams.output_dir, "test_targets.txt")
        # Write predictions and targets for later ROUGE evaluation; the with
        # statement closes both files on exit.
        with open(output_test_predictions_file, "w+") as p_writer, open(output_test_targets_file, "w+") as t_writer:
            for output_batch in outputs:
                p_writer.writelines(s + "\n" for s in output_batch["preds"])
                t_writer.writelines(s + "\n" for s in output_batch["target"])
        return self.test_end(outputs)

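    # The dataloaders below build the train/val/test splits of
    # SummarizationDataset. train_dataloader also computes t_total, the total
    # number of optimizer steps across all epochs (accounting for GPUs and
    # gradient accumulation), and attaches a linear warmup schedule to the
    # optimizer self.opt, assumed to be created by BaseTransformer.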
    def train_dataloader(self):
        train_dataset = SummarizationDataset(
            self.tokenizer, data_dir=self.hparams.data_dir, type_path="train", block_size=self.hparams.max_seq_length
        )
        dataloader = DataLoader(train_dataset, batch_size=self.hparams.train_batch_size)
        t_total = (
            (len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.n_gpu)))
            // self.hparams.gradient_accumulation_steps
            * float(self.hparams.num_train_epochs)
        )
        scheduler = get_linear_schedule_with_warmup(
            self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
        )
        self.lr_scheduler = scheduler
        return dataloader

    def val_dataloader(self):
        val_dataset = SummarizationDataset(
            self.tokenizer, data_dir=self.hparams.data_dir, type_path="val", block_size=self.hparams.max_seq_length
        )
        return DataLoader(val_dataset, batch_size=self.hparams.eval_batch_size)

    def test_dataloader(self):
        test_dataset = SummarizationDataset(
            self.tokenizer, data_dir=self.hparams.data_dir, type_path="test", block_size=self.hparams.max_seq_length
        )
        return DataLoader(test_dataset, batch_size=self.hparams.eval_batch_size)

    @staticmethod
    def add_model_specific_args(parser, root_dir):
        BaseTransformer.add_model_specific_args(parser, root_dir)
        # Add BART-specific options.
        parser.add_argument(
            "--max_seq_length",
            default=1024,
            type=int,
            help="The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--data_dir",
            default=None,
            type=str,
            required=True,
            help="The input data dir. Should contain the dataset files for the CNN/DM summarization task.",
        )
        return parser

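# Entry point: assemble generic and model-specific arguments, train, and
# optionally run prediction on the test set using the last saved checkpoint.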
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    add_generic_args(parser, os.getcwd())
    parser = BartSystem.add_model_specific_args(parser, os.getcwd())
    args = parser.parse_args()

    # If output_dir is not provided, generate a timestamped folder in pwd.
    if args.output_dir is None:
        args.output_dir = os.path.join("./results", f"{args.task}_{args.model_type}_{time.strftime('%Y%m%d_%H%M%S')}")
        os.makedirs(args.output_dir)

    model = BartSystem(args)
    trainer = generic_train(model, args)

    # Optionally, predict on the test set and write results to output_dir.
    if args.do_predict:
        # Reload the most recent checkpoint; load_from_checkpoint is a
        # classmethod that returns a new model, so bind its result.
        checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True)))
        model = BartSystem.load_from_checkpoint(checkpoints[-1])
        trainer.test(model)
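
# Example invocation (a sketch: the path is illustrative, and the generic flag
# names such as --model_type, --output_dir, and --do_predict are defined in
# transformer_base, so their exact spelling there may differ):
#
#   python run_bart_sum.py \
#       --data_dir=./cnn_dm \
#       --model_type=bart \
#       --output_dir=./results/bart_sum \
#       --do_predict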