BART for summarization training with CNN/DM using pytorch-lightning
parent eaabaaf750 · commit 3d76df3a12
@@ -14,6 +14,19 @@ python evaluate_cnn.py <path_to_test.source> cnn_test_summaries.txt
```
The default batch size, 8, fits in 16 GB of GPU memory, but it may need to be adjusted for your system.
### Training

After downloading the CNN and Daily Mail datasets, preprocess the data:
```commandline
git clone https://github.com/artmatsak/cnn-dailymail
cd cnn-dailymail && python make_datafiles.py ../cnn/stories/ ../dailymail/stories/
```
Run the training script: `run_train.sh`
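If the default batch size does not fit in your GPU memory, the batch-size flags can also be overridden by invoking the underlying Python script directly. A minimal sketch, assuming the arguments defined in `run_bart_sum.py` and the shared `transformer_base.py` (see `run_train.sh` below for the full invocation):

```commandline
python run_bart_sum.py \
--data_dir=./cnn-dailymail/cnn_dm \
--model_type=bart \
--model_name_or_path=bart-large \
--train_batch_size=2 \
--eval_batch_size=2 \
--output_dir=./bart_sum \
--do_train
```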
### Where is the code?

The core model is in `src/transformers/modeling_bart.py`. This directory only contains examples.
examples/summarization/bart/run_bart_sum.py (Normal file, 172 lines)
@@ -0,0 +1,172 @@
import argparse
import glob
import logging
import os
import time

import torch
from torch.utils.data import DataLoader

from transformer_base import BaseTransformer, add_generic_args, generic_train, get_linear_schedule_with_warmup
from utils import SummarizationDataset


logger = logging.getLogger(__name__)


class BartSystem(BaseTransformer):

    mode = "language-modeling"

    def __init__(self, hparams):
        super(BartSystem, self).__init__(hparams, num_labels=None, mode=self.mode)

    def forward(
        self, input_ids, attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, lm_labels=None
    ):
        return self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            lm_labels=lm_labels,
        )

    def _step(self, batch):
        # Teacher forcing: decoder inputs drop the last target token, labels drop the first,
        # and padded label positions are set to -100 so the loss ignores them.
        y = batch["target_ids"]
        y_ids = y[:, :-1].contiguous()
        lm_labels = y[:, 1:].clone()
        lm_labels[y[:, 1:] == self.tokenizer.pad_token_id] = -100
        outputs = self(
            input_ids=batch["source_ids"],
            attention_mask=batch["source_mask"],
            decoder_input_ids=y_ids,
            lm_labels=lm_labels,
        )

        loss = outputs[0]

        return loss

    def training_step(self, batch, batch_idx):
        loss = self._step(batch)

        tensorboard_logs = {"train_loss": loss}
        return {"loss": loss, "log": tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        loss = self._step(batch)
        return {"val_loss": loss}

    def validation_end(self, outputs):
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        tensorboard_logs = {"val_loss": avg_loss}
        return {"avg_val_loss": avg_loss, "log": tensorboard_logs}

    def test_step(self, batch, batch_idx):
        generated_ids = self.model.generate(
            batch["source_ids"],
            attention_mask=batch["source_mask"],
            num_beams=1,
            max_length=80,
            repetition_penalty=2.5,
            length_penalty=1.0,
            early_stopping=True,
        )
        preds = [
            self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True)
            for g in generated_ids
        ]
        target = [
            self.tokenizer.decode(t, skip_special_tokens=True, clean_up_tokenization_spaces=True)
            for t in batch["target_ids"]
        ]
        loss = self._step(batch)

        return {"val_loss": loss, "preds": preds, "target": target}

    def test_end(self, outputs):
        return self.validation_end(outputs)

    def test_epoch_end(self, outputs):
        output_test_predictions_file = os.path.join(self.hparams.output_dir, "test_predictions.txt")
        output_test_targets_file = os.path.join(self.hparams.output_dir, "test_targets.txt")
        # write predictions and targets for later rouge evaluation.
        with open(output_test_predictions_file, "w+") as p_writer, open(output_test_targets_file, "w+") as t_writer:
            for output_batch in outputs:
                p_writer.writelines(s + "\n" for s in output_batch["preds"])
                t_writer.writelines(s + "\n" for s in output_batch["target"])
            p_writer.close()
            t_writer.close()

        return self.test_end(outputs)

    def train_dataloader(self):
        train_dataset = SummarizationDataset(
            self.tokenizer, data_dir=self.hparams.data_dir, type_path="train", block_size=self.hparams.max_seq_length
        )
        dataloader = DataLoader(train_dataset, batch_size=self.hparams.train_batch_size)
        # Total number of optimizer steps, used to size the linear warmup/decay schedule.
        t_total = (
            (len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.n_gpu)))
            // self.hparams.gradient_accumulation_steps
            * float(self.hparams.num_train_epochs)
        )
        scheduler = get_linear_schedule_with_warmup(
            self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
        )
        self.lr_scheduler = scheduler
        return dataloader

    def val_dataloader(self):
        val_dataset = SummarizationDataset(
            self.tokenizer, data_dir=self.hparams.data_dir, type_path="val", block_size=self.hparams.max_seq_length
        )
        return DataLoader(val_dataset, batch_size=self.hparams.eval_batch_size)

    def test_dataloader(self):
        test_dataset = SummarizationDataset(
            self.tokenizer, data_dir=self.hparams.data_dir, type_path="test", block_size=self.hparams.max_seq_length
        )
        return DataLoader(test_dataset, batch_size=self.hparams.eval_batch_size)

    @staticmethod
    def add_model_specific_args(parser, root_dir):
        BaseTransformer.add_model_specific_args(parser, root_dir)
        # Add BART specific options
        parser.add_argument(
            "--max_seq_length",
            default=1024,
            type=int,
            help="The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )

        parser.add_argument(
            "--data_dir",
            default=None,
            type=str,
            required=True,
            help="The input data dir. Should contain the dataset files for the CNN/DM summarization task.",
        )
        return parser


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    add_generic_args(parser, os.getcwd())
    parser = BartSystem.add_model_specific_args(parser, os.getcwd())
    args = parser.parse_args()

    # If output_dir not provided, a folder will be generated in pwd
    if args.output_dir is None:
        args.output_dir = os.path.join("./results", f"{args.task}_{args.model_type}_{time.strftime('%Y%m%d_%H%M%S')}",)
        os.makedirs(args.output_dir)

    model = BartSystem(args)
    trainer = generic_train(model, args)

    # Optionally, predict on dev set and write to output_dir
    if args.do_predict:
        checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True)))
        BartSystem.load_from_checkpoint(checkpoints[-1])
        trainer.test(model)
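Once training has written a checkpoint, the system can be reloaded for ad-hoc generation. A rough sketch (not part of the commit; the checkpoint path is hypothetical and follows the `checkpointepoch=*.ckpt` pattern used above):

```python
from run_bart_sum import BartSystem

# Hypothetical checkpoint produced by generic_train in the output directory.
system = BartSystem.load_from_checkpoint("bart_sum/checkpointepoch=2.ckpt")
system.eval()

article = "..."  # one full CNN/DM story as a single string
inputs = system.tokenizer.batch_encode_plus([article], max_length=1024, return_tensors="pt")

# Same generation settings as test_step above, minus the repetition/length penalties.
summary_ids = system.model.generate(
    inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    num_beams=1,
    max_length=80,
    early_stopping=True,
)
print(system.tokenizer.decode(summary_ids[0], skip_special_tokens=True))
```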
examples/summarization/bart/run_train.sh (Executable file, 23 lines)
@@ -0,0 +1,23 @@
# Install the newest pytorch-lightning.
pip install -U git+http://github.com/PyTorchLightning/pytorch-lightning/


export OUTPUT_DIR_NAME=bart_sum
export CURRENT_DIR=${PWD}
export OUTPUT_DIR=${CURRENT_DIR}/${OUTPUT_DIR_NAME}

# Make output directory if it doesn't exist
mkdir -p $OUTPUT_DIR

# Add parent directory to python path to access transformer_base.py
export PYTHONPATH="../../":"${PYTHONPATH}"

python run_bart_sum.py \
--data_dir=./cnn-dailymail/cnn_dm \
--model_type=bart \
--model_name_or_path=bart-large \
--learning_rate=3e-5 \
--train_batch_size=4 \
--eval_batch_size=4 \
--output_dir=$OUTPUT_DIR \
--do_train
examples/summarization/bart/utils.py (Normal file, 43 lines)
@@ -0,0 +1,43 @@
import os

from torch.utils.data import Dataset


class SummarizationDataset(Dataset):
    def __init__(self, tokenizer, data_dir="./cnn-dailymail/cnn_dm/", type_path="train", block_size=1024):
        super().__init__()
        self.tokenizer = tokenizer

        self.source = []
        self.target = []

        print("loading " + type_path + " source.")

        with open(os.path.join(data_dir, type_path + ".source"), "r") as f:
            for text in f.readlines():  # each text is a line and a full story
                tokenized = tokenizer.batch_encode_plus(
                    [text], max_length=block_size, pad_to_max_length=True, return_tensors="pt"
                )
                self.source.append(tokenized)

        print("loading " + type_path + " target.")

        with open(os.path.join(data_dir, type_path + ".target"), "r") as f:
            for text in f.readlines():  # each text is a line and a summary
                # Summaries are padded/truncated to a fixed maximum of 56 tokens.
                tokenized = tokenizer.batch_encode_plus(
                    [text], max_length=56, pad_to_max_length=True, return_tensors="pt"
                )
                self.target.append(tokenized)

    def __len__(self):
        return len(self.source)

    def __getitem__(self, index):
        source_ids = self.source[index]["input_ids"].squeeze()
        target_ids = self.target[index]["input_ids"].squeeze()

        src_mask = self.source[index]["attention_mask"].squeeze()

        return {"source_ids": source_ids, "source_mask": src_mask, "target_ids": target_ids}
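Each dataset item is a dict of tensors keyed the way `BartSystem` expects. A quick sketch of inspecting the data (not part of the commit; it assumes the preprocessed `val.source`/`val.target` files exist and uses `BartTokenizer` as the tokenizer):

```python
from torch.utils.data import DataLoader
from transformers import BartTokenizer

from utils import SummarizationDataset

tokenizer = BartTokenizer.from_pretrained("bart-large")
dataset = SummarizationDataset(tokenizer, data_dir="./cnn-dailymail/cnn_dm/", type_path="val")
batch = next(iter(DataLoader(dataset, batch_size=2)))

print(batch["source_ids"].shape)   # torch.Size([2, 1024]) -- articles padded to block_size
print(batch["source_mask"].shape)  # torch.Size([2, 1024]) -- attention mask over the articles
print(batch["target_ids"].shape)   # torch.Size([2, 56])   -- reference summaries padded to 56 tokens
```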
@@ -53,10 +53,9 @@ class BaseTransformer(pl.LightningModule):
         super(BaseTransformer, self).__init__()
         self.hparams = hparams
         self.hparams.model_type = self.hparams.model_type.lower()

         config = AutoConfig.from_pretrained(
             self.hparams.config_name if self.hparams.config_name else self.hparams.model_name_or_path,
-            num_labels=num_labels,
+            **({"num_labels": num_labels} if num_labels is not None else {}),
             cache_dir=self.hparams.cache_dir if self.hparams.cache_dir else None,
         )
         tokenizer = AutoTokenizer.from_pretrained(