[examples] SummarizationModule improvements (#4951)

This commit is contained in:
Sam Shleifer 2020-06-17 13:51:34 -04:00 committed by GitHub
parent cd40f6564e
commit 043f9f51f9
15 changed files with 1465 additions and 348 deletions

View File

@ -2,6 +2,8 @@ import argparse
import logging
import os
import random
from pathlib import Path
from typing import Any, Dict
import numpy as np
import pytorch_lightning as pl
@ -13,10 +15,13 @@ from transformers import (
AutoModel,
AutoModelForPreTraining,
AutoModelForQuestionAnswering,
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
AutoModelForTokenClassification,
AutoModelWithLMHead,
AutoTokenizer,
PretrainedConfig,
PreTrainedTokenizer,
get_linear_schedule_with_warmup,
)
@ -31,6 +36,8 @@ MODEL_MODES = {
"pretraining": AutoModelForPreTraining,
"token-classification": AutoModelForTokenClassification,
"language-modeling": AutoModelWithLMHead,
"summarization": AutoModelForSeq2SeqLM,
"translation": AutoModelForSeq2SeqLM,
}
@ -38,33 +45,59 @@ def set_seed(args: argparse.Namespace):
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.n_gpu > 0:
if args.gpus > 0:
torch.cuda.manual_seed_all(args.seed)
class BaseTransformer(pl.LightningModule):
def __init__(self, hparams: argparse.Namespace, num_labels=None, mode="base", **config_kwargs):
def __init__(
self,
hparams: argparse.Namespace,
num_labels=None,
mode="base",
config=None,
tokenizer=None,
model=None,
**config_kwargs
):
"Initialize a model."
super().__init__()
self.hparams = hparams
self.step_count = 0
self.tfmr_ckpts = {}
self.output_dir = Path(self.hparams.output_dir)
cache_dir = self.hparams.cache_dir if self.hparams.cache_dir else None
self.config = AutoConfig.from_pretrained(
self.hparams.config_name if self.hparams.config_name else self.hparams.model_name_or_path,
**({"num_labels": num_labels} if num_labels is not None else {}),
cache_dir=cache_dir,
**config_kwargs,
)
self.tokenizer = AutoTokenizer.from_pretrained(
self.hparams.tokenizer_name if self.hparams.tokenizer_name else self.hparams.model_name_or_path,
cache_dir=cache_dir,
)
self.model = MODEL_MODES[mode].from_pretrained(
self.hparams.model_name_or_path,
from_tf=bool(".ckpt" in self.hparams.model_name_or_path),
config=self.config,
cache_dir=cache_dir,
)
if config is None:
self.config = AutoConfig.from_pretrained(
self.hparams.config_name if self.hparams.config_name else self.hparams.model_name_or_path,
**({"num_labels": num_labels} if num_labels is not None else {}),
cache_dir=cache_dir,
**config_kwargs,
)
else:
self.config: PretrainedConfig = config
if tokenizer is None:
self.tokenizer = AutoTokenizer.from_pretrained(
self.hparams.tokenizer_name if self.hparams.tokenizer_name else self.hparams.model_name_or_path,
cache_dir=cache_dir,
)
else:
self.tokenizer: PreTrainedTokenizer = tokenizer
if model is None:
self.model_type = MODEL_MODES[mode]
self.model = self.model_type.from_pretrained(
self.hparams.model_name_or_path,
from_tf=bool(".ckpt" in self.hparams.model_name_or_path),
config=self.config,
cache_dir=cache_dir,
)
else:
self.model_type = None
self.model = model
def load_hf_checkpoint(self, *args, **kwargs):
self.model = self.model_type.from_pretrained(*args, **kwargs)
def is_logger(self):
return self.trainer.proc_rank <= 0
@ -138,6 +171,15 @@ class BaseTransformer(pl.LightningModule):
),
)
@pl.utilities.rank_zero_only
def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
save_path = self.output_dir.joinpath("best_tfmr")
save_path.mkdir(exist_ok=True)
self.model.config.save_step = self.step_count
self.model.save_pretrained(save_path)
self.tokenizer.save_pretrained(save_path)
self.tfmr_ckpts[self.step_count] = save_path
@staticmethod
def add_model_specific_args(parser, root_dir):
parser.add_argument(
@ -152,7 +194,7 @@ class BaseTransformer(pl.LightningModule):
)
parser.add_argument(
"--tokenizer_name",
default="",
default=None,
type=str,
help="Pretrained tokenizer name or path if not the same as model_name",
)
@ -165,7 +207,7 @@ class BaseTransformer(pl.LightningModule):
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
parser.add_argument("--warmup_steps", default=500, type=int, help="Linear warmup over warmup_steps.")
parser.add_argument(
"--num_train_epochs", default=3, type=int, help="Total number of training epochs to perform."
)
@ -199,7 +241,8 @@ class LoggingCallback(pl.Callback):
writer.write("{} = {}\n".format(key, str(metrics[key])))
def add_generic_args(parser, root_dir):
def add_generic_args(parser, root_dir) -> None:
# TODO(SS): allow all pl args? parser = pl.Trainer.add_argparse_args(parser)
parser.add_argument(
"--output_dir",
default=None,
@ -221,8 +264,8 @@ def add_generic_args(parser, root_dir):
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
"See details at https://nvidia.github.io/apex/amp.html",
)
parser.add_argument("--n_gpu", type=int, default=1)
parser.add_argument("--fast_dev_run", action="store_true")
parser.add_argument("--gpus", type=int, default=1)
parser.add_argument("--n_tpu_cores", type=int, default=0)
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
@ -235,28 +278,32 @@ def add_generic_args(parser, root_dir):
)
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
parser.add_argument("--resume_from_checkpoint", type=str, default=None)
parser.add_argument("--val_check_interval", default=1.0, type=float)
def generic_train(model: BaseTransformer, args: argparse.Namespace):
def generic_train(
model: BaseTransformer,
args: argparse.Namespace,
early_stopping_callback=False,
logger=True, # can pass WandbLogger() here
extra_callbacks=[],
checkpoint_callback=None,
logging_callback=None,
**extra_train_kwargs
):
# init model
set_seed(args)
odir = Path(model.hparams.output_dir)
odir.mkdir(exist_ok=True)
if checkpoint_callback is None:
checkpoint_callback = pl.callbacks.ModelCheckpoint(
filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=1
)
if logging_callback is None:
logging_callback = LoggingCallback()
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train:
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
checkpoint_callback = pl.callbacks.ModelCheckpoint(
filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=5
)
train_params = dict(
accumulate_grad_batches=args.gradient_accumulation_steps,
gpus=args.n_gpu,
max_epochs=args.num_train_epochs,
early_stop_callback=False,
gradient_clip_val=args.max_grad_norm,
checkpoint_callback=checkpoint_callback,
callbacks=[LoggingCallback()],
)
train_params = {}
if args.fp16:
train_params["use_amp"] = args.fp16
@ -269,12 +316,27 @@ def generic_train(model: BaseTransformer, args: argparse.Namespace):
train_params["num_tpu_cores"] = args.n_tpu_cores
train_params["gpus"] = 0
if args.n_gpu > 1:
if args.gpus > 1:
train_params["distributed_backend"] = "ddp"
trainer = pl.Trainer(**train_params)
trainer = pl.Trainer(
logger=logger,
accumulate_grad_batches=args.gradient_accumulation_steps,
gpus=args.gpus,
max_epochs=args.num_train_epochs,
early_stop_callback=early_stopping_callback,
gradient_clip_val=args.max_grad_norm,
checkpoint_callback=checkpoint_callback,
callbacks=[logging_callback] + extra_callbacks,
fast_dev_run=args.fast_dev_run,
val_check_interval=args.val_check_interval,
weights_summary=None,
resume_from_checkpoint=args.resume_from_checkpoint,
**train_params,
)
if args.do_train:
trainer.fit(model)
trainer.logger.log_hyperparams(args)
trainer.logger.save()
return trainer

View File

@ -5,5 +5,6 @@ psutil
sacrebleu
rouge-score
tensorflow_datasets
pytorch-lightning==0.7.3 # April 10, 2020 release
pytorch-lightning==0.7.6
matplotlib
git-python==1.0.3

View File

@ -1,47 +1,70 @@
### Get CNN Data
To be able to reproduce the authors' results on the CNN/Daily Mail dataset you first need to download both CNN and Daily Mail datasets [from Kyunghyun Cho's website](https://cs.nyu.edu/~kcho/DMQA/) (the links next to "Stories") in the same folder. Then uncompress the archives by running:
### Data
CNN/DailyMail data
```bash
cd examples/summarization
wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/cnn_dm.tgz
tar -xzvf cnn_dm.tgz
export CNN_DIR=${PWD}/cnn_dm
```
This should create a directory called `cnn_dm/` with files like `test.source`.
To use your own data, copy that file format. Each article to be summarized goes on its own line.
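For reference, here is a minimal sketch of a custom dataset in this format. The file names come from the `--data_dir` help text added in `finetune.py` below; the directory name and contents are made up for illustration.
```bash
# Build a tiny toy dataset in the expected layout (illustrative contents only).
mkdir -p my_dataset
printf "First article to summarize.\nSecond article to summarize.\n" > my_dataset/train.source
printf "First summary.\nSecond summary.\n" > my_dataset/train.target
for split in val test; do
  cp my_dataset/train.source my_dataset/$split.source
  cp my_dataset/train.target my_dataset/$split.target
done
export MY_DATA_DIR=${PWD}/my_dataset  # pass as --data_dir to finetune.sh
```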
XSUM Data:
```bash
cd examples/summarization
wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/xsum.tar.gz
tar -xzvf xsum.tar.gz
export XSUM_DIR=${PWD}/xsum
```
### Evaluation
To create summaries for each article in the dataset, run:
```bash
python evaluate_cnn.py <path_to_test.source> test_generations.txt <model-name> --score_path rouge_scores.txt
python run_eval.py <path_to_test.source> test_generations.txt <model-name> --score_path rouge_scores.txt
```
The default batch size, 8, fits in 16GB GPU memory, but may need to be adjusted to fit your system.
The default batch size, 4, fits in 16GB GPU memory, but may need to be adjusted to fit your system.
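For example, a full invocation with ROUGE scoring might look like the following. The flags are the ones defined in `run_eval.py` in this diff; the model name and file paths are illustrative, and the name of the reference-summary file may differ depending on how your data was prepared.
```bash
python run_eval.py \
    $CNN_DIR/test.source \
    test_generations.txt \
    facebook/bart-large-cnn \
    --reference_path $CNN_DIR/test.target \
    --score_path rouge_scores.txt \
    --bs 4 \
    --fp16
```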
### Training
Run/modify `finetune_bart.sh` or `finetune_t5.sh`
Run/modify `finetune.sh`
### Stanford CoreNLP Setup
The following command should work on a 16GB GPU:
```bash
export me=`git config user.name`
./finetune.sh \
--data_dir $XSUM_DIR \
--train_batch_size=1 \
--eval_batch_size=1 \
--output_dir="$me"_xsum_results \
--num_train_epochs 1
```
```bash
ptb_tokenize () {
  cat $1 | java edu.stanford.nlp.process.PTBTokenizer -ioFileList -preserveLines > $2
}
sudo apt install openjdk-8-jre-headless
sudo apt-get install ant
wget http://nlp.stanford.edu/software/stanford-corenlp-full-2018-10-05.zip
unzip stanford-corenlp-full-2018-10-05.zip
cd stanford-corenlp-full-2018-10-05
export CLASSPATH=stanford-corenlp-3.9.2.jar:stanford-corenlp-3.9.2-models.jar
```
Then run `ptb_tokenize` on `test.target` and your generated hypotheses.
### Rouge Setup
Install `files2rouge` following the instructions at [here](https://github.com/pltrdy/files2rouge).
I also needed to run `sudo apt-get install libxml-parser-perl`
Tips:
- 1 epoch at batch size 1 for bart-large takes 24 hours and requires 13GB of GPU RAM with fp16 on an NVIDIA V100.
- try `bart-base`, `--freeze_encoder`, or `--freeze_embeds` for faster training and a larger batch size (3hr/epoch with bs=8, see below).
- `fp16_opt_level=O1` (the default) works best.
- If you are finetuning on your own dataset, start from `bart-large-cnn` if you want long summaries and `bart-large-xsum` if you want short summaries.
(It rarely makes sense to start from `bart-large` unless you are researching finetuning methods.)
- In addition to the pytorch-lightning .ckpt checkpoint, a transformers checkpoint will be saved.
Load it with `BartForConditionalGeneration.from_pretrained(f'{output_dir}/best_tfmr')` (see the sketch after this list).
- At the moment, `--do_predict` does not work in a multi-GPU setting. You need to use `evaluate_checkpoint` or the `run_eval.py` code.
- If you want to run experiments on improving the summarization finetuning process, try the XSUM Shared Task (below). It's faster to train than CNNDM because the summaries are shorter.
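As a minimal sketch of the `best_tfmr` tip above: `on_save_checkpoint` in `lightning_base.py` writes both the model and the tokenizer to `{output_dir}/best_tfmr`, so both can be reloaded with `from_pretrained`. The output directory name below is illustrative.
```python
# Minimal sketch, assuming training has finished and written <output_dir>/best_tfmr.
# "bart_xsum_results" is an illustrative output_dir, not a path from this repo.
from transformers import AutoTokenizer, BartForConditionalGeneration

output_dir = "bart_xsum_results"
model = BartForConditionalGeneration.from_pretrained(f"{output_dir}/best_tfmr")
tokenizer = AutoTokenizer.from_pretrained(f"{output_dir}/best_tfmr")

# Summarize one article, mirroring the generate/batch_decode usage in run_eval.py.
batch = tokenizer.batch_encode_plus(["Some long article to summarize."], return_tensors="pt")
summary_ids = model.generate(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True))
```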
```python
from files2rouge import files2rouge
from files2rouge import settings
files2rouge.run(<path_to_tokenized_hypo>,
                <path_to_tokenized_target>,
                saveto='rouge_output.txt')
```
### XSUM Shared Task
Compare XSUM results with others by using `--logger wandb_shared`. This requires `wandb` registration.
Here is an example command:
```bash
export me=`git config user.name`
./finetune.sh \
--data_dir $XSUM_DIR \
--output_dir "$me"_xsum_frozen_embs \
--logger wandb_shared \
--train_batch_size 16 --eval_batch_size 16 --freeze_embeds --freeze_encoder \
--num_train_epochs 6
```
Results can be viewed [here](https://app.wandb.ai/sshleifer/hf_summarization/table?workspace=user-)

View File

@ -0,0 +1,85 @@
import logging
import os
from pathlib import Path
import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities import rank_zero_only
def count_trainable_parameters(model):
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
return params
logger = logging.getLogger(__name__)
class Seq2SeqLoggingCallback(pl.Callback):
def _write_logs(
self, trainer: pl.Trainer, pl_module: pl.LightningModule, type_path: str, save_generations=True
) -> None:
logger.info(f"***** {type_path} results at step {trainer.global_step:05d} *****")
if not pl_module.is_logger():
return
metrics = trainer.callback_metrics
trainer.logger.log_metrics({k: v for k, v in metrics.items() if k not in ["log", "progress_bar", "preds"]})
# Log results
od = Path(pl_module.hparams.output_dir)
if type_path == "test":
results_file = od / "test_results.txt"
generations_file = od / "test_generations.txt"
else:
results_file = od / f"{type_path}_results_{trainer.global_step:05d}.txt"
generations_file = od / f"{type_path}_generations_{trainer.global_step:05d}.txt"
with open(results_file, "a+") as writer:
for key in sorted(metrics):
if key in ["log", "progress_bar", "preds"]:
continue
val = metrics[key]
if isinstance(val, torch.Tensor):
val = val.item()
msg = f"{key}: {val:.6f}\n"
writer.write(msg)
if not save_generations:
return
if "preds" in metrics:
content = "\n".join(metrics["preds"])
generations_file.open("w+").write(content)
@rank_zero_only
def on_train_start(self, trainer, pl_module):
try:
npars = pl_module.model.model.num_parameters()
except AttributeError:
npars = pl_module.model.num_parameters()
n_trainable_pars = count_trainable_parameters(pl_module)
# mp stands for million parameters
trainer.logger.log_metrics({"n_params": npars, "mp": npars / 1e6, "grad_mp": n_trainable_pars / 1e6})
@rank_zero_only
def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
return self._write_logs(trainer, pl_module, "val")
@rank_zero_only
def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
return self._write_logs(trainer, pl_module, "test")
def get_rouge2_checkpoint_callback(output_dir):
"""Saves the best model by validation ROUGE2 score."""
checkpoint_callback = ModelCheckpoint(
filepath=os.path.join(output_dir, "{val_avg_rouge2:.4f}-{step_count}"),
monitor="val_rouge",
mode="max",
save_top_k=1,
period=0, # maybe save a checkpoint every time val is run, not just end of epoch.
)
return checkpoint_callback

View File

@ -0,0 +1,448 @@
import argparse
import gc
import os
from pathlib import Path
from typing import List
import pytorch_lightning as pl
import torch
from torch import nn
from torch.nn import functional as F
from lightning_base import generic_train
from transformers import AdamW, BartConfig, BartForConditionalGeneration, T5Config, T5ForConditionalGeneration
try:
from .finetune import SummarizationModule
from .initialization_utils import init_student, copy_layers
from .utils import (
use_task_specific_params,
SummarizationDataset,
pickle_load,
freeze_params,
assert_all_frozen,
any_requires_grad,
)
from .finetune import main as ft_main
except ImportError:
from finetune import SummarizationModule
from finetune import main as ft_main
from initialization_utils import init_student, copy_layers
from utils import (
use_task_specific_params,
SummarizationDataset,
pickle_load,
freeze_params,
assert_all_frozen,
any_requires_grad,
)
class SummarizationDistiller(SummarizationModule):
loss_names = ["loss", "ce_loss", "mlm_loss", "enc_mse_loss", "hid_loss_enc", "hid_loss_dec"]
def __init__(self, hparams):
assert Path(hparams.data_dir).exists()
d_layers_to_copy, student, student_cfg, teacher = self.pre_init(hparams)
super().__init__(hparams, model=student, config=student_cfg)
self.teacher = teacher
use_task_specific_params(self.teacher, "summarization")
freeze_params(self.teacher)
self.sanity_check_gradients()
self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean")
self.temperature = 2.0
self.alpha_mlm = hparams.alpha_mlm
self.alpha_ce = hparams.alpha_ce
self.alpha_hid = hparams.alpha_hid
# self.alpha_cos = hparams.alpha_cos
self.alpha_encoder_loss = self.hparams.alpha_encoder_loss
gc.collect()
torch.cuda.empty_cache()
def sanity_check_gradients(self):
assert_all_frozen(self.teacher)
assert_all_frozen(self.model.model.decoder.embed_tokens)
assert_all_frozen(self.model.model.encoder.embed_tokens)
if self.different_encoder:
assert any_requires_grad(self.model.model.encoder)
else:
freeze_params(self.model.model.encoder)
del self.teacher.model.encoder
def pre_init(self, hparams):
# Dump empty student model at a path, then call from_pretrained on it
teacher = BartForConditionalGeneration.from_pretrained(hparams.teacher).eval()
student_updates = {
"decoder_layers": hparams.student_decoder_layers,
"encoder_layers": hparams.student_encoder_layers,
}
d_layers_to_copy = get_layers_to_copy(student_updates["decoder_layers"], teacher.config.decoder_layers)
e_layers_to_copy: List = get_layers_to_copy(student_updates["encoder_layers"], teacher.config.encoder_layers)
hparams.d_layer_to_copy = d_layers_to_copy
hparams.e_layer_to_copy = e_layers_to_copy
kw = teacher.config.to_diff_dict()
kw.update(student_updates)
# Copy weights
student_cfg = BartConfig(**kw)
student = BartForConditionalGeneration(student_cfg)
student, _ = init_student(student, teacher)
self.copy_to_student(d_layers_to_copy, e_layers_to_copy, hparams, student, teacher)
Path(hparams.output_dir).mkdir(exist_ok=True)
return d_layers_to_copy, student, student_cfg, teacher
def copy_to_student(self, d_layers_to_copy, e_layers_to_copy, hparams, student, teacher):
if teacher.config.model_type == "t5":
return self.copy_t5_to_student(d_layers_to_copy, e_layers_to_copy, hparams, student, teacher)
self.different_encoder: bool = hparams.student_encoder_layers != teacher.config.encoder_layers
self.different_decoder = hparams.student_decoder_layers != teacher.config.decoder_layers
if self.different_decoder:
copy_layers(teacher.model.decoder.layers, student.model.decoder.layers, d_layers_to_copy)
if self.different_encoder:
copy_layers(teacher.model.encoder.layers, student.model.encoder.layers, e_layers_to_copy)
def copy_t5_to_student(self, d_layers_to_copy, e_layers_to_copy, hparams, student, teacher):
self.different_encoder: bool = hparams.student_encoder_layers != teacher.config.num_layers
self.different_decoder = hparams.student_decoder_layers != teacher.config.num_layers
if self.different_decoder:
copy_layers(teacher.decoder.block, student.decoder.block, d_layers_to_copy)
if self.different_encoder:
copy_layers(teacher.encoder.block, student.encoder.block, e_layers_to_copy)
def get_dataset(self, type_path) -> SummarizationDataset:
n_obs = self.n_obs[type_path]
dataset = SummarizationDataset(self.tokenizer, type_path=type_path, n_obs=n_obs, **self.dataset_kwargs)
return dataset
def calc_mse_loss(self, teacher_outputs: torch.Tensor, student_outputs: torch.Tensor, mask) -> torch.FloatTensor:
if mask is not None:
# mask has False at padding_idx
sel_mask = mask[:, :, None].expand_as(student_outputs).bool()
s_logits_slct = torch.masked_select(student_outputs, sel_mask)
t_logits_slct = torch.masked_select(teacher_outputs, sel_mask)
else:
t_logits_slct = teacher_outputs
s_logits_slct = student_outputs
return F.mse_loss(s_logits_slct, t_logits_slct)
def calc_ce_loss(self, mask, s_logits, t_logits):
if mask is not None:
# mask has False at padding_idx
sel_mask = mask[:, :, None].expand_as(s_logits)
s_logits_slct = torch.masked_select(
s_logits, sel_mask
) # (bs * seq_length * voc_size) modulo the 1s in mask
t_logits_slct = torch.masked_select(
t_logits, sel_mask
) # (bs * seq_length * voc_size) modulo the 1s in mask
else:
t_logits_slct = t_logits
s_logits_slct = s_logits # (bs * seq_length * voc_size) modulo the 1s in mask
s_logits_slct = s_logits_slct.view(-1, s_logits.size(-1)) # (bs * seq_length, voc_size) modulo the 1s in mask
t_logits_slct = t_logits_slct.view(-1, s_logits.size(-1)) # (bs * seq_length, voc_size) modulo the 1s in mask
assert t_logits_slct.size() == s_logits_slct.size()
loss_ce = (
self.ce_loss_fct(
F.log_softmax(s_logits_slct / self.temperature, dim=-1),
F.softmax(t_logits_slct / self.temperature, dim=-1),
)
* (self.temperature) ** 2
)
return loss_ce, s_logits_slct, t_logits_slct
def configure_optimizers(self):
"Prepare optimizer and schedule (linear warmup and decay)"
model = self.model
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
"weight_decay": self.hparams.weight_decay,
},
{
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
"weight_decay": 0.0,
},
]
optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)
self.opt = optimizer
return [optimizer]
@staticmethod
def add_model_specific_args(parser, root_dir):
SummarizationModule.add_model_specific_args(parser, root_dir)
parser.add_argument("--teacher", default="facebook/bart-large-cnn", type=str)
parser.add_argument("--alpha_ce", default=0.8, type=float)
parser.add_argument("--alpha_mlm", default=0.2, type=float)
# parser.add_argument("--alpha_cos", default=0.0, type=float)
parser.add_argument("--alpha_encoder_loss", default=0.0, type=float)
parser.add_argument("--alpha_hid", default=0.0, type=float, required=False)
parser.add_argument(
"--student_decoder_layers", default=12, type=int, required=False,
)
parser.add_argument(
"--student_encoder_layers", default=12, type=int, required=False,
)
parser.add_argument(
"--no_teacher", action="store_true", default=False,
)
parser.add_argument( # TODO: remove
"--enc_only", action="store_true", default=False,
)
return parser
def _step(self, batch):
# assert is_frozen(self.teacher)
pad_token_id = self.tokenizer.pad_token_id
input_ids, src_mask, y = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]
decoder_input_ids = y[:, :-1].contiguous()
labels = y[:, 1:].clone()
labels[y[:, 1:] == pad_token_id] = -100
# noinspection PyCallingNonCallable
sloss, slogits, dec_hidden, enc_outputs, enc_hidden_state = self(
input_ids,
attention_mask=src_mask,
decoder_input_ids=decoder_input_ids,
labels=labels,
output_hidden_states=True,
output_attentions=False,
)
def zero_tensor():
return torch.tensor(0.0).type_as(sloss)
loss_encoder, hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor(), zero_tensor()
if self.different_encoder:
with torch.no_grad():
teacher_enc_outputs, teacher_enc_hid, _ = self.teacher.model.encoder(
input_ids, attention_mask=src_mask, output_hidden_states=True
)
if self.hparams.alpha_encoder_loss > 0:
loss_encoder = self.calc_mse_loss(enc_outputs, teacher_enc_outputs, src_mask)
hid_loss_enc = self.calc_hidden_loss(
src_mask, enc_hidden_state, teacher_enc_hid, self.hparams.e_layer_to_copy
)
teacher_enc_outputs = (enc_outputs,)
assert isinstance(teacher_enc_outputs, tuple), type(teacher_enc_outputs)
with torch.no_grad():
tloss, tlogits, tdec_hidden, _ = self.teacher(
input_ids,
attention_mask=src_mask,
encoder_outputs=teacher_enc_outputs,
decoder_input_ids=decoder_input_ids,
lm_labels=labels,
output_hidden_states=True,
)
dec_mask = decoder_input_ids.ne(pad_token_id)
loss_ce, s_logits_slct, t_logits_slct = self.calc_ce_loss(dec_mask, slogits, tlogits)
if self.alpha_hid > 0:
hid_loss_dec = self.calc_hidden_loss(dec_mask, dec_hidden, tdec_hidden, self.hparams.d_layer_to_copy)
blended_loss = (
self.alpha_ce * loss_ce
+ self.alpha_mlm * sloss
+ self.hparams.alpha_encoder_loss * loss_encoder
+ self.hparams.alpha_hid * (hid_loss_enc + hid_loss_dec)
)
return blended_loss, loss_ce, sloss, loss_encoder, hid_loss_enc, hid_loss_dec
def calc_hidden_loss(self, attention_mask, hidden_states, hidden_states_T, matches):
assert not isinstance(
hidden_states, torch.Tensor
), f"expected list or tuple for hidden_states, got tensor of shape {hidden_states.shape}"
assert not isinstance(
hidden_states_T, torch.Tensor
), f"expected list or tuple for hidden_states_T, got tensor of shape {hidden_states_T.shape}"
mask = attention_mask.to(hidden_states[0])
valid_count = mask.sum() * hidden_states[0].size(-1)
hidden_losses = [
(F.mse_loss(hidden_states[i], hidden_states_T[j], reduction="none") * mask.unsqueeze(-1)).sum()
/ valid_count
for i, j in enumerate(matches)
]
return sum(hidden_losses)
class T5SummarizationDistiller(SummarizationDistiller):
def pre_init(self, hparams):
teacher = T5ForConditionalGeneration.from_pretrained(hparams.teacher)
n_layer = hparams.student_decoder_layers
assert n_layer == hparams.student_encoder_layers # TODO(SS): relax this
d_layers_to_copy = get_layers_to_copy(n_layer, len(teacher.decoder.block))
e_layers_to_copy: List = get_layers_to_copy(n_layer, len(teacher.encoder.block))
student_updates = {"num_layers": n_layer}
hparams.d_layer_to_copy = d_layers_to_copy
hparams.e_layer_to_copy = e_layers_to_copy
kw = teacher.config.to_diff_dict()
kw.update(student_updates)
# Copy weights
student_cfg = T5Config(**kw)
student = T5ForConditionalGeneration(student_cfg)
student, _ = init_student(student, teacher)
self.copy_to_student(d_layers_to_copy, e_layers_to_copy, hparams, student, teacher)
Path(hparams.output_dir).mkdir(exist_ok=True)
task_specific_params = student.config.task_specific_params
if task_specific_params is not None:
student.config.update(task_specific_params.get("summarization", {}))
return d_layers_to_copy, student, student_cfg, teacher
def freeze_embeds(self):
freeze_params(self.model.shared)
for d in [self.model.encoder, self.model.decoder]:
freeze_params(d.embed_tokens)
def sanity_check_gradients(self):
"""T5"""
assert_all_frozen(self.teacher)
assert_all_frozen(self.model.decoder.embed_tokens)
assert_all_frozen(self.model.encoder.embed_tokens)
if self.different_encoder:
assert any_requires_grad(self.model.encoder)
else:
freeze_params(self.model.encoder)
del self.teacher.model.encoder
if self.different_decoder:
assert any_requires_grad(self.model.decoder)
else:
freeze_params(self.model.decoder) # TODO(SS): very suspicious
def _step(self, batch):
pad_token_id = self.tokenizer.pad_token_id
source_ids, source_mask, y = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]
decoder_input_ids = y[:, :-1].contiguous()
labels = y[:, 1:].clone()
labels[y[:, 1:] == pad_token_id] = -100
# noinspection PyCallingNonCallable
dec_mask = decoder_input_ids.ne(pad_token_id)
sloss, slogits, dec_hidden, enc_outputs, enc_hidden_state = self(
source_ids,
attention_mask=source_mask,
decoder_input_ids=decoder_input_ids,
labels=labels,
output_hidden_states=True,
output_attentions=False,
use_cache=False,
)
def zero_tensor():
return torch.tensor(0.0).type_as(sloss)
loss_encoder, hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor(), zero_tensor()
if self.different_encoder:
with torch.no_grad():
teacher_enc_outputs, teacher_enc_hid = self.teacher.encoder(
source_ids, attention_mask=source_mask, output_hidden_states=True, use_cache=False,
)
if self.hparams.alpha_encoder_loss > 0:
loss_encoder = self.calc_mse_loss(enc_outputs, teacher_enc_outputs, source_mask)
hid_loss_enc = self.calc_hidden_loss(
source_mask, enc_hidden_state, teacher_enc_hid, self.hparams.e_layer_to_copy
)
teacher_enc_outputs = (enc_outputs,)
assert isinstance(teacher_enc_outputs, tuple), type(teacher_enc_outputs)
with torch.no_grad():
tloss, tlogits, tdec_hidden, _ = self.teacher(
source_ids,
attention_mask=source_mask,
encoder_outputs=teacher_enc_outputs,
decoder_input_ids=decoder_input_ids,
lm_labels=labels,
output_hidden_states=True,
use_cache=False,
)
loss_ce, s_logits_slct, t_logits_slct = self.calc_ce_loss(dec_mask, slogits, tlogits)
if self.alpha_hid > 0:
hid_loss_dec = self.calc_hidden_loss(dec_mask, dec_hidden, tdec_hidden, self.hparams.d_layer_to_copy)
blended_loss = (
self.alpha_ce * loss_ce
+ self.alpha_mlm * sloss
+ self.hparams.alpha_encoder_loss * loss_encoder
+ self.hparams.alpha_hid * (hid_loss_enc + hid_loss_dec)
)
return blended_loss, loss_ce, sloss, loss_encoder, hid_loss_enc, hid_loss_dec
def create_module(args):
t5 = "t5" in args.model_name_or_path
if args.no_teacher:
assert not args.enc_only
module_cls = SummarizationModule
elif t5:
module_cls = T5SummarizationDistiller
elif args.enc_only:
raise ValueError("Deleted that")
else:
module_cls = SummarizationDistiller
args.setup_cls: str = module_cls.__name__
model = module_cls(args)
return model
def evaluate_checkpoint(ckpt_path: Path, dest_dir=None):
exp_dir = ckpt_path.parent
if dest_dir is None:
dest_dir = exp_dir
clash = list(dest_dir.glob("test_generations*"))
if clash:
print(f"SKIPPING to avoid overwriting {clash}")
ckpt = torch.load(ckpt_path, map_location="cpu")
if "hparams" in ckpt:
args = argparse.Namespace(**ckpt["hparams"])
else:
args = argparse.Namespace(**pickle_load(exp_dir / "hparams.pkl"))
args.resume_from_checkpoint = str(ckpt_path)
args.do_train = False
args.output_dir = str(dest_dir)
args.n_gpu = 1
args.eval_batch_size = 16
Path(args.output_dir).mkdir(exist_ok=True)
model = create_module(args)
trainer: pl.Trainer = generic_train(model, args, early_stopping_callback=False)
trainer.test(model)
def get_layers_to_copy(n_to_get, tot):
all_layers = list(range(tot))
if tot == 12: # Alternating for special cases
layers_to_copy = { # maps # layers in student -> which teacher layers to copy
6: [0, 2, 4, 7, 9, 11],
1: [11],
3: [0, 6, 11],
2: [0, 11],
4: [0, 4, 8, 11],
9: [0, 1, 2, 4, 5, 7, 9, 10, 11],
12: all_layers,
}
return layers_to_copy[n_to_get]
else:
return all_layers[:n_to_get]
def distill_main(args):
Path(args.output_dir).mkdir(exist_ok=True)
if len(os.listdir(args.output_dir)) > 3 and args.do_train:
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
model = create_module(args)
return ft_main(args, model=model)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser = SummarizationDistiller.add_model_specific_args(parser, os.getcwd())
args = parser.parse_args()
distill_main(args)

View File

@ -1,100 +0,0 @@
import argparse
from pathlib import Path
import torch
from rouge_score import rouge_scorer, scoring
from tqdm import tqdm
from transformers import AutoModelWithLMHead, AutoTokenizer
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
def chunks(lst, n):
"""Yield successive n-sized chunks from lst."""
for i in range(0, len(lst), n):
yield lst[i : i + n]
def generate_summaries(
examples: list, out_file: str, model_name: str, batch_size: int = 8, device: str = DEFAULT_DEVICE
):
fout = Path(out_file).open("w", encoding="utf-8")
model = AutoModelWithLMHead.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# update config with summarization specific params
task_specific_params = model.config.task_specific_params
if task_specific_params is not None:
model.config.update(task_specific_params.get("summarization", {}))
for batch in tqdm(list(chunks(examples, batch_size))):
if "t5" in model_name:
batch = [model.config.prefix + text for text in batch]
dct = tokenizer.batch_encode_plus(batch, max_length=1024, return_tensors="pt", pad_to_max_length=True).to(
device
)
summaries = model.generate(**dct)
dec = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)
for hypothesis in dec:
fout.write(hypothesis + "\n")
fout.flush()
def calculate_rouge(output_lns, reference_lns, score_path):
score_file = Path(score_path).open("w")
scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
aggregator = scoring.BootstrapAggregator()
for reference_ln, output_ln in zip(reference_lns, output_lns):
scores = scorer.score(reference_ln, output_ln)
aggregator.add_scores(scores)
result = aggregator.aggregate()
score_file.write(
"ROUGE_1: \n{} \n\n ROUGE_2: \n{} \n\n ROUGE_L: \n{} \n\n".format(
result["rouge1"], result["rouge2"], result["rougeL"]
)
)
def run_generate():
parser = argparse.ArgumentParser()
parser.add_argument(
"input_path", type=str, help="like cnn_dm/test.source or cnn_dm/test_articles_input.txt",
)
parser.add_argument(
"output_path", type=str, help="where to save summaries",
)
parser.add_argument(
"model_name",
type=str,
default="facebook/bart-large-cnn",
help="like bart-large-cnn,'t5-small', 't5-base', 't5-large', 't5-3b', 't5-11b",
)
parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test_reference_summaries.txt")
parser.add_argument(
"--score_path", type=str, required=False, help="where to save the rouge score",
)
parser.add_argument(
"--device", type=str, required=False, default=DEFAULT_DEVICE, help="cuda, cuda:1, cpu etc.",
)
parser.add_argument(
"--bs", type=int, default=8, required=False, help="batch size: how many to summarize at a time",
)
args = parser.parse_args()
examples = [" " + x.rstrip() if "t5" in args.model_name else x.rstrip() for x in open(args.input_path).readlines()]
generate_summaries(examples, args.output_path, args.model_name, batch_size=args.bs, device=args.device)
if args.score_path is not None:
output_lns = [x.rstrip() for x in open(args.output_path).readlines()]
reference_lns = [x.rstrip() for x in open(args.reference_path).readlines()]
calculate_rouge(output_lns, reference_lns, args.score_path)
if __name__ == "__main__":
run_generate()

View File

@ -3,91 +3,169 @@ import glob
import logging
import os
import time
from pathlib import Path
from typing import Dict, List, Tuple
import numpy as np
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader
from lightning_base import BaseTransformer, add_generic_args, generic_train, get_linear_schedule_with_warmup
from lightning_base import BaseTransformer, add_generic_args, generic_train
from transformers import get_linear_schedule_with_warmup
try:
from .utils import SummarizationDataset
from .utils import (
use_task_specific_params,
SummarizationDataset,
lmap,
flatten_list,
pickle_save,
save_git_info,
freeze_params,
calculate_rouge,
get_git_info,
ROUGE_KEYS,
)
from .callbacks import Seq2SeqLoggingCallback, get_rouge2_checkpoint_callback
except ImportError:
from utils import SummarizationDataset
from utils import (
use_task_specific_params,
SummarizationDataset,
lmap,
flatten_list,
pickle_save,
save_git_info,
freeze_params,
calculate_rouge,
get_git_info,
ROUGE_KEYS,
)
from callbacks import Seq2SeqLoggingCallback, get_rouge2_checkpoint_callback
logger = logging.getLogger(__name__)
class SummarizationTrainer(BaseTransformer):
class SummarizationModule(BaseTransformer):
mode = "summarization"
loss_names = ["loss"]
mode = "language-modeling"
def __init__(self, hparams, **kwargs):
super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs)
use_task_specific_params(self.model, "summarization")
save_git_info(self.hparams.output_dir)
self.metrics_save_path = Path(self.output_dir) / "metrics.pkl"
self.hparams_save_path = Path(self.output_dir) / "hparams.pkl"
self.step_count = 0
self.metrics = {"train": [], "val": [], "test": []}
def __init__(self, hparams):
super().__init__(hparams, num_labels=None, mode=self.mode)
self.dataset_kwargs: dict = dict(
data_dir=self.hparams.data_dir,
max_source_length=self.hparams.max_source_length,
max_target_length=self.hparams.max_target_length,
prefix=self.model.config.prefix or "",
)
n_observations_per_split = {
"train": self.hparams.n_train,
"val": self.hparams.n_val,
"test": self.hparams.n_test,
}
self.n_obs = {k: v if v >= 0 else None for k, v in n_observations_per_split.items()}
def forward(self, input_ids, attention_mask=None, decoder_input_ids=None, lm_labels=None):
return self.model(
input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, lm_labels=lm_labels,
self.target_lens = {
"train": self.hparams.max_target_length,
"val": self.hparams.val_max_target_length,
"test": self.hparams.test_max_target_length,
}
assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"
if self.hparams.freeze_embeds:
self.freeze_embeds()
if self.hparams.freeze_encoder:
freeze_params(self.model.model.encoder) # TODO: this will break for t5
self.hparams.git_sha = get_git_info()["repo_sha"]
self.num_workers = 4 if self.hparams.gpus <= 1 else None # passing num_workers breaks lightning for multigpu
def freeze_embeds(self):
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
if self.model.config.model_type == "bart":
freeze_params(self.model.model.shared)
for d in [self.model.model.encoder, self.model.model.decoder]:
freeze_params(d.embed_positions)
freeze_params(d.embed_tokens)
else:
freeze_params(self.model.shared)
for d in [self.model.encoder, self.model.decoder]:
freeze_params(d.embed_tokens)
def forward(self, input_ids, **kwargs):
return self.model(input_ids, **kwargs)
def ids_to_clean_text(self, generated_ids: List[int]):
gen_text = self.tokenizer.batch_decode(
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
)
return lmap(str.strip, gen_text)
def _step(self, batch):
def _step(self, batch: dict) -> Tuple:
pad_token_id = self.tokenizer.pad_token_id
source_ids, source_mask, y = batch["source_ids"], batch["source_mask"], batch["target_ids"]
source_ids, source_mask, y = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]
y_ids = y[:, :-1].contiguous()
lm_labels = y[:, 1:].clone()
lm_labels[y[:, 1:] == pad_token_id] = -100
outputs = self(source_ids, attention_mask=source_mask, decoder_input_ids=y_ids, lm_labels=lm_labels,)
outputs = self(source_ids, attention_mask=source_mask, decoder_input_ids=y_ids, labels=lm_labels,)
loss = outputs[0]
return (loss,)
return loss
def training_step(self, batch, batch_idx) -> Dict:
loss_tensors = self._step(batch)
logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
return {"loss": loss_tensors[0], "log": logs}
def training_step(self, batch, batch_idx):
loss = self._step(batch)
def validation_step(self, batch, batch_idx) -> Dict:
return self._generative_step(batch)
tensorboard_logs = {"train_loss": loss}
return {"loss": loss, "log": tensorboard_logs}
def validation_end(self, outputs, prefix="val") -> Dict:
self.step_count += 1
losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
loss = losses["loss"]
rouges = {k: np.array([x[k] for x in outputs]).mean() for k in ROUGE_KEYS + ["gen_time", "summ_len"]}
rouge_tensor: torch.FloatTensor = torch.tensor(rouges["rouge2"]).type_as(loss)
rouges.update({k: v.item() for k, v in losses.items()})
losses.update(rouges)
metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
metrics["step_count"] = self.step_count
self.save_metrics(metrics, prefix) # writes to self.metrics_save_path
preds = flatten_list([x["preds"] for x in outputs])
return {"log": metrics, "preds": preds, f"{prefix}_loss": loss, f"{prefix}_rouge": rouge_tensor}
def validation_step(self, batch, batch_idx):
loss = self._step(batch)
return {"val_loss": loss}
def save_metrics(self, metrics, prefix) -> None:
self.metrics[prefix].append(metrics)
pickle_save(self.metrics, self.metrics_save_path)
def validation_end(self, outputs):
avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
tensorboard_logs = {"val_loss": avg_loss}
return {"avg_val_loss": avg_loss, "log": tensorboard_logs}
def test_step(self, batch, batch_idx):
def _generative_step(self, batch):
pad_token_id = self.tokenizer.pad_token_id
source_ids, source_mask, y = SummarizationDataset.trim_seq2seq_batch(batch, pad_token_id)
# NOTE: the following kwargs get more speed and lower quality summaries than those in evaluate_cnn.py
generated_ids = self.model.generate(
input_ids=source_ids,
attention_mask=source_mask,
num_beams=1,
max_length=80,
repetition_penalty=2.5,
length_penalty=1.0,
early_stopping=True,
use_cache=True,
)
preds = [
self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True)
for g in generated_ids
]
target = [self.tokenizer.decode(t, skip_special_tokens=True, clean_up_tokenization_spaces=True) for t in y]
loss = self._step(batch)
# TODO(SS): task specific params
return {"val_loss": loss, "preds": preds, "target": target}
t0 = time.time()
generated_ids = self.model.generate(input_ids=source_ids, attention_mask=source_mask, use_cache=True,)
gen_time = time.time() - t0
preds = self.ids_to_clean_text(generated_ids)
target = self.ids_to_clean_text(y)
loss_tensors = self._step(batch)
base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
rouge: Dict = calculate_rouge(preds, target)
summ_len = np.mean(lmap(len, generated_ids))
base_metrics.update(gen_time=gen_time, summ_len=summ_len, preds=preds, target=target, **rouge)
return base_metrics
def test_step(self, batch, batch_idx):
return self._generative_step(batch)
def test_end(self, outputs):
return self.validation_end(outputs)
return self.validation_end(outputs, prefix="test")
def test_epoch_end(self, outputs):
output_test_predictions_file = os.path.join(self.hparams.output_dir, "test_predictions.txt")
@ -102,15 +180,43 @@ class SummarizationTrainer(BaseTransformer):
return self.test_end(outputs)
def validation_epoch_end(self, outputs):
self.validation_end(outputs, "val")
def get_dataset(self, type_path) -> SummarizationDataset:
n_obs = self.n_obs[type_path]
max_target_length = self.target_lens[type_path]
dataset = SummarizationDataset(
self.tokenizer,
type_path=type_path,
n_obs=n_obs,
max_target_length=max_target_length,
**self.dataset_kwargs,
)
return dataset
def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
dataset = SummarizationDataset(self.tokenizer, type_path=type_path, **self.dataset_kwargs)
dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=dataset.collate_fn, shuffle=shuffle)
dataset = self.get_dataset(type_path)
sampler = None
if self.hparams.sortish_sampler and type_path == "train":
assert self.hparams.gpus <= 1 # TODO: assert earlier
sampler = dataset.make_sortish_sampler(batch_size)
shuffle = False
dataloader = DataLoader(
dataset,
batch_size=batch_size,
collate_fn=dataset.collate_fn,
shuffle=shuffle,
num_workers=self.num_workers,
sampler=sampler,
)
return dataloader
def train_dataloader(self) -> DataLoader:
dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
t_total = (
(len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.n_gpu)))
(len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.gpus)))
// self.hparams.gradient_accumulation_steps
* float(self.hparams.num_train_epochs)
)
@ -129,7 +235,7 @@ class SummarizationTrainer(BaseTransformer):
@staticmethod
def add_model_specific_args(parser, root_dir):
BaseTransformer.add_model_specific_args(parser, root_dir)
# Add BART specific options
add_generic_args(parser, root_dir)
parser.add_argument(
"--max_source_length",
default=1024,
@ -144,41 +250,82 @@ class SummarizationTrainer(BaseTransformer):
help="The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded.",
)
parser.add_argument(
"--val_max_target_length",
default=142, # these defaults are optimized for CNNDM. For xsum, see README.md.
type=int,
help="The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded.",
)
parser.add_argument(
"--test_max_target_length",
default=142,
type=int,
help="The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded.",
)
parser.add_argument(
"--data_dir",
default=None,
type=str,
required=True,
help="The input data dir. Should contain the dataset files for the CNN/DM summarization task.",
help="The input data dir. Should contain train.source, train.target, val.source, val.target, test.source, test.target",
)
parser.add_argument("--freeze_encoder", action="store_true")
parser.add_argument("--freeze_embeds", action="store_true")
parser.add_argument("--sortish_sampler", action="store_true", default=False)
parser.add_argument("--logger", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
parser.add_argument("--n_val", type=int, default=500, required=False, help="# examples. -1 means use all.")
parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.")
return parser
def main(args):
def main(args, model=None) -> SummarizationModule:
Path(args.output_dir).mkdir(exist_ok=True)
if len(os.listdir(args.output_dir)) > 3 and args.do_train:
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
if model is None:
model: BaseTransformer = SummarizationModule(args)
if (
args.logger == "default"
or args.fast_dev_run
or str(args.output_dir).startswith("/tmp")
or str(args.output_dir).startswith("/var")
):
logger = True # don't pollute wandb logs unnecessarily
elif args.logger == "wandb":
from pytorch_lightning.loggers import WandbLogger
# If output_dir not provided, a folder will be generated in pwd
if not args.output_dir:
args.output_dir = os.path.join("./results", f"{args.task}_{time.strftime('%Y%m%d_%H%M%S')}",)
os.makedirs(args.output_dir)
model = SummarizationTrainer(args)
trainer = generic_train(model, args)
logger = WandbLogger(name=model.output_dir.name)
elif args.logger == "wandb_shared":
from pytorch_lightning.loggers import WandbLogger
# Optionally, predict on dev set and write to output_dir
if args.do_predict:
# See https://github.com/huggingface/transformers/issues/3159
# pl use this format to create a checkpoint:
# https://github.com/PyTorchLightning/pytorch-lightning/blob/master\
# /pytorch_lightning/callbacks/model_checkpoint.py#L169
checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True)))
model = model.load_from_checkpoint(checkpoints[-1])
trainer.test(model)
# TODO: separate LB for CNN, we should use Path(args.data_dir).name to determine the correct LB.
logger = WandbLogger(name=model.output_dir.name, project="hf_summarization")
trainer: pl.Trainer = generic_train(
model,
args,
logging_callback=Seq2SeqLoggingCallback(),
checkpoint_callback=get_rouge2_checkpoint_callback(args.output_dir),
logger=logger,
# TODO: early stopping callback seems messed up
)
if not args.do_predict:
return model
model.hparams.test_checkpoint = ""
checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True)))
if checkpoints:
model.hparams.test_checkpoint = checkpoints[-1]
trainer.resume_from_checkpoint = checkpoints[-1]
trainer.logger.log_hyperparams(model.hparams)
trainer.test(model) # this breaks in DDP, known lightning issue. See evaluate_checkpoint to recover metrics.
return model
if __name__ == "__main__":
parser = argparse.ArgumentParser()
add_generic_args(parser, os.getcwd())
parser = SummarizationTrainer.add_model_specific_args(parser, os.getcwd())
parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
args = parser.parse_args()
main(args)

View File

@ -0,0 +1,23 @@
export OUTPUT_DIR=bart_cnn_finetune
# Make output directory if it doesn't exist
mkdir -p $OUTPUT_DIR
# Add parent directory to python path to access lightning_base.py
export PYTHONPATH="../":"${PYTHONPATH}"
# --model_name_or_path=t5-base for t5
python finetune.py \
--model_name_or_path=facebook/bart-large \
--learning_rate=3e-5 \
--fp16 \
--gpus 1 \
--do_train \
--do_predict \
--n_val 1000 \
--val_check_interval 0.1 \
--sortish_sampler \
--max_target_length=56 \
$@

View File

@ -1,18 +0,0 @@
export OUTPUT_DIR_NAME=bart_sum
export CURRENT_DIR=${PWD}
export OUTPUT_DIR=${CURRENT_DIR}/${OUTPUT_DIR_NAME}
# Make output directory if it doesn't exist
mkdir -p $OUTPUT_DIR
# Add parent directory to python path to access lightning_base.py
export PYTHONPATH="../":"${PYTHONPATH}"
python finetune.py \
--data_dir=./cnn-dailymail/cnn_dm \
--model_name_or_path=bart-large \
--learning_rate=3e-5 \
--train_batch_size=4 \
--eval_batch_size=4 \
--output_dir=$OUTPUT_DIR \
--do_train $@

View File

@ -0,0 +1,20 @@
from typing import List
from torch import nn
def init_student(student, teacher):
teacher_state_dict = teacher.state_dict()
info = student.load_state_dict(teacher_state_dict, strict=False)
assert info.missing_keys == [], info.missing_keys
return student, info
def copy_decoder_layers(teacher, student, l2copy=[0, 2, 4, 7, 9, 11]):
copy_layers(teacher.model.decoder.layers, student.model.decoder.layers, l2copy)
def copy_layers(teacher_layers: nn.ModuleList, student_layers: nn.ModuleList, layers_to_copy: List) -> None:
layers_to_copy = nn.ModuleList([l for i, l in enumerate(teacher_layers) if i in layers_to_copy])
assert len(student_layers) == len(layers_to_copy), f"{len(student_layers)} != {len(layers_to_copy)}"
student_layers.load_state_dict(layers_to_copy.state_dict())

View File

@ -0,0 +1,12 @@
#CNN_DIR = /home/shleifer/transformers_fork/examples/summarization/bart/cnn_dm
# Add parent directory to python path to access lightning_base.py
export PYTHONPATH="../":"${PYTHONPATH}"
python distillation.py \
--learning_rate=3e-4 \
--do_train \
--do_predict \
--fp16 \
--val_check_interval 0.1 \
$@

View File

@ -0,0 +1,78 @@
import argparse
import json
from pathlib import Path
import torch
from tqdm import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
try:
from .finetune import calculate_rouge, use_task_specific_params
except ImportError:
from finetune import calculate_rouge, use_task_specific_params
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
def chunks(lst, n):
"""Yield successive n-sized chunks from lst."""
for i in range(0, len(lst), n):
yield lst[i : i + n]
def generate_summaries(
examples: list, out_file: str, model_name: str, batch_size: int = 8, device: str = DEFAULT_DEVICE, fp16=False,
) -> None:
fout = Path(out_file).open("w", encoding="utf-8")
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
if fp16:
model = model.half()
tokenizer = AutoTokenizer.from_pretrained(model_name)
# update config with summarization specific params
use_task_specific_params(model, "summarization")
for batch in tqdm(list(chunks(examples, batch_size))):
if "t5" in model_name:
batch = [model.config.prefix + text for text in batch]
dct = tokenizer.batch_encode_plus(batch, max_length=1024, return_tensors="pt", pad_to_max_length=True).to(
device
)
summaries = model.generate(**dct)
dec = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)
for hypothesis in dec:
fout.write(hypothesis + "\n")
fout.flush()
def run_generate():
parser = argparse.ArgumentParser()
parser.add_argument("input_path", type=str, help="like cnn_dm/test.source")
parser.add_argument("output_path", type=str, help="where to save summaries")
parser.add_argument("model_name", type=str, help="like facebook/bart-large-cnn,t5-base, etc.")
parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test_reference_summaries.txt")
parser.add_argument("--score_path", type=str, required=False, help="where to save the rouge score in json format")
parser.add_argument("--device", type=str, required=False, default=DEFAULT_DEVICE, help="cuda, cuda:1, cpu etc.")
parser.add_argument("--bs", type=int, default=8, required=False, help="batch size")
parser.add_argument("--fp16", action="store_true")
args = parser.parse_args()
examples = [" " + x.rstrip() if "t5" in args.model_name else x.rstrip() for x in open(args.input_path).readlines()]
generate_summaries(
examples, args.output_path, args.model_name, batch_size=args.bs, device=args.device, fp16=args.fp16
)
if args.score_path is not None:
output_lns = [x.rstrip() for x in open(args.output_path).readlines()]
reference_lns = [x.rstrip() for x in open(args.reference_path).readlines()]
rouge: dict = calculate_rouge(output_lns, reference_lns)
json.dump(rouge, open(args.score_path, "w+"))
if __name__ == "__main__":
run_generate()

View File

@ -7,28 +7,40 @@ import unittest
from pathlib import Path
from unittest.mock import patch
import torch
from torch.utils.data import DataLoader
from transformers import BartTokenizer
from .evaluate_cnn import run_generate
from .distillation import distill_main, evaluate_checkpoint
from .finetune import main
from .utils import SummarizationDataset
from .run_eval import generate_summaries, run_generate
from .utils import SummarizationDataset, lmap, pickle_load
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
DEFAULT_ARGS = {
FP16_EVER = False
CHEAP_ARGS = {
"logger": "default",
"alpha_hid": 0,
"freeze_embeds": True,
"enc_only": False,
"tgt_suffix": "",
"resume_from_checkpoint": None,
"sortish_sampler": True,
"student_decoder_layers": 1,
"val_check_interval": 1.0,
"output_dir": "",
"fp16": False,
"no_teacher": False,
"fp16_opt_level": "O1",
"n_gpu": 1,
"gpus": 1 if torch.cuda.is_available() else 0,
"n_tpu_cores": 0,
"max_grad_norm": 1.0,
"do_train": True,
"do_predict": False,
"do_predict": True,
"gradient_accumulation_steps": 1,
"server_ip": "",
"server_port": "",
@ -36,7 +48,7 @@ DEFAULT_ARGS = {
"model_type": "bart",
"model_name_or_path": "sshleifer/bart-tiny-random",
"config_name": "",
"tokenizer_name": "",
"tokenizer_name": "facebook/bart-large",
"cache_dir": "",
"do_lower_case": False,
"learning_rate": 3e-05,
@ -48,6 +60,17 @@ DEFAULT_ARGS = {
"eval_batch_size": 2,
"max_source_length": 12,
"max_target_length": 12,
"val_max_target_length": 12,
"test_max_target_length": 12,
"fast_dev_run": False,
"no_cache": False,
"n_train": -1,
"n_val": -1,
"n_test": -1,
"student_encoder_layers": 1,
"alpha_loss_encoder": 0.0,
"freeze_encoder": False,
"auto_scale_batch_size": False,
}
@ -56,6 +79,9 @@ def _dump_articles(path: Path, articles: list):
f.write("\n".join(articles))
BDIR = Path("~/transformers_fork/examples/summarization/bart/").absolute()
def make_test_data_dir():
tmp_dir = Path(tempfile.gettempdir())
articles = [" Sam ate lunch today", "Sams lunch ingredients"]
@ -66,6 +92,169 @@ def make_test_data_dir():
return tmp_dir
@unittest.skip("These won't pass until hidden_states kwarg is merged.")
class TestSummarizationDistiller(unittest.TestCase):
@classmethod
def setUpClass(cls):
logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks
return cls
@unittest.skipUnless(torch.cuda.device_count() > 1, "skipping multiGPU test")
def test_bdc_multigpu(self):
updates = dict(
student_encoder_layers=2,
student_decoder_layers=1,
no_teacher=True,
freeze_encoder=True,
gpus=2,
sortish_sampler=False,
)
self._bart_distiller_cli(updates)
@unittest.skipUnless(torch.cuda.is_available(), "skipping fp16 test")
def test_bdc_fp16(self):
updates = dict(
student_encoder_layers=2,
student_decoder_layers=1,
alpha_hid=3.0,
freeze_encoder=True,
gpus=1,
fp16=FP16_EVER,
fp16_opt_level="O1",
)
self._bart_distiller_cli(updates)
@unittest.skipUnless(torch.cuda.is_available(), "skipping fp16 test")
def test_bdc_t5_eval_fp16(self):
updates = dict(
fp16=FP16_EVER,
gpus=1,
model_type="t5",
model_name_or_path="patrickvonplaten/t5-tiny-random",
do_train=False,
do_predict=True,
tokenizer_name=None,
no_teacher=True,
)
self._bart_distiller_cli(updates, check_contents=False)
@unittest.skipUnless(torch.cuda.is_available(), "skipping fp16 test")
def test_bdc_t5_train_fp16(self):
updates = dict(
fp16=FP16_EVER,
gpus=1,
model_type="t5",
model_name_or_path="patrickvonplaten/t5-tiny-random",
do_train=True,
do_predict=True,
tokenizer_name="patrickvonplaten/t5-tiny-random",
no_teacher=True,
)
self._bart_distiller_cli(updates)
def test_bdc_no_teacher(self):
updates = dict(student_encoder_layers=2, student_decoder_layers=1, no_teacher=True,)
self._bart_distiller_cli(updates)
def test_bdc_yes_teacher(self):
updates = dict(student_encoder_layers=2, student_decoder_layers=1,)
self._bart_distiller_cli(updates)
def test_bdc_checkpointing(self):
updates = dict(
student_encoder_layers=2,
student_decoder_layers=1,
num_train_epochs=4,
val_check_interval=0.25,
alpha_hid=2.0,
)
model = self._bart_distiller_cli(updates, check_contents=False)
ckpts = list(Path(model.output_dir).glob("*.ckpt"))
self.assertEqual(1, len(ckpts))
transformer_ckpts = list(Path(model.output_dir).glob("**/*.bin"))
self.assertEqual(len(transformer_ckpts), len(ckpts))
new_transformer_ckpts = list(Path(model.output_dir).glob("**/*.bin"))
self.assertEqual(len(new_transformer_ckpts), 1)
examples = lmap(str.strip, model.hparams.data_dir.joinpath("test.source").open().readlines())
out_path = tempfile.mktemp()
generate_summaries(examples, out_path, new_transformer_ckpts[0].parent)
self.assertTrue(Path(out_path).exists())
evaluate_checkpoint(ckpts[0], dest_dir=Path(tempfile.mkdtemp()))
def test_bdc_t5(self):
updates = dict(
student_encoder_layers=1,
student_decoder_layers=1,
alpha_hid=2.0,
teacher="patrickvonplaten/t5-tiny-random",
model_type="t5",
model_name_or_path="patrickvonplaten/t5-tiny-random",
tokenizer_name="patrickvonplaten/t5-tiny-random",
)
self._bart_distiller_cli(updates)
def test_bdc_t5_eval(self):
updates = dict(
model_type="t5",
model_name_or_path="patrickvonplaten/t5-tiny-random",
do_train=False,
do_predict=True,
tokenizer_name="patrickvonplaten/t5-tiny-random",
no_teacher=True,
)
self._bart_distiller_cli(updates, check_contents=False)
def _bart_distiller_cli(self, updates, check_contents=True):
default_updates = dict(
model_type="bart",
train_batch_size=1,
eval_batch_size=2,
num_train_epochs=2,
alpha_mlm=0.2,
alpha_ce=0.8,
do_predict=True,
gpus=1 if torch.cuda.is_available() else 0,
model_name_or_path="sshleifer/tinier_bart",
teacher=CHEAP_ARGS["model_name_or_path"],
val_check_interval=0.5,
alpha_encoder_loss=0.4,
)
default_updates.update(updates)
args_d: dict = CHEAP_ARGS.copy()
tmp_dir = make_test_data_dir()
output_dir = tempfile.mkdtemp(prefix="output_")
args_d.update(data_dir=tmp_dir, output_dir=output_dir, **default_updates)
model = distill_main(argparse.Namespace(**args_d))
if not check_contents:
return model
contents = os.listdir(output_dir)
ckpt_name = "val_avg_rouge2=0.0000-step_count=2.ckpt" # "val_avg_rouge2=0.0000-epoch=1.ckpt" # "epoch=1-val_avg_rouge2=0.0000.ckpt"
contents = {os.path.basename(p) for p in contents}
self.assertIn(ckpt_name, contents)
self.assertIn("metrics.pkl", contents)
self.assertIn("test_generations.txt", contents)
self.assertIn("val_generations_1.txt", contents)
self.assertIn("val_1_results.txt", contents)
self.assertIn("test_results.txt", contents)
# self.assertEqual(len(contents), 15)
metrics = pickle_load(Path(output_dir) / "metrics.pkl")
import pandas as pd
val_df = pd.DataFrame(metrics["val"])
train_df = pd.DataFrame(metrics["train"])
test_df = pd.DataFrame(metrics["test"])
desired_n_evals = args_d["num_train_epochs"] * 2 + 1
self.assertEqual(val_df.shape[0], desired_n_evals)
self.assertEqual(test_df.shape[1], val_df.shape[1])
self.assertEqual(train_df.shape[0], 0)
return model
class TestBartExamples(unittest.TestCase):
@classmethod
def setUpClass(cls):
@@ -79,49 +268,31 @@ class TestBartExamples(unittest.TestCase):
output_file_name = Path(tempfile.gettempdir()) / "utest_output_bart_sum.hypo"
articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]
_dump_articles(tmp, articles)
testargs = ["evaluate_cnn.py", str(tmp), str(output_file_name), "sshleifer/bart-tiny-random"]
testargs = ["run_eval.py", str(tmp), str(output_file_name), "sshleifer/bart-tiny-random"]
with patch.object(sys, "argv", testargs):
run_generate()
self.assertTrue(Path(output_file_name).exists())
os.remove(Path(output_file_name))
def test_bart_run_sum_cli(self):
args_d: dict = DEFAULT_ARGS.copy()
tmp_dir = make_test_data_dir()
output_dir = tempfile.mkdtemp(prefix="output_")
args_d.update(
data_dir=tmp_dir, model_type="bart", train_batch_size=2, eval_batch_size=2, n_gpu=0, output_dir=output_dir,
)
main(argparse.Namespace(**args_d))
args_d.update({"do_train": False, "do_predict": True})
main(argparse.Namespace(**args_d))
contents = os.listdir(output_dir)
expected_contents = {
"checkpointepoch=0.ckpt",
"test_results.txt",
}
created_files = {os.path.basename(p) for p in contents}
self.assertSetEqual(expected_contents, created_files)
def test_t5_run_sum_cli(self):
args_d: dict = DEFAULT_ARGS.copy()
args_d: dict = CHEAP_ARGS.copy()
tmp_dir = make_test_data_dir()
output_dir = tempfile.mkdtemp(prefix="output_")
args_d.update(
data_dir=tmp_dir,
model_type="t5",
model_name_or_path="patrickvonplaten/t5-tiny-random",
tokenizer_name=None, # "patrickvonplaten/t5-tiny-random",
train_batch_size=2,
eval_batch_size=2,
n_gpu=0,
gpus=0,
output_dir=output_dir,
do_predict=True,
)
main(argparse.Namespace(**args_d))
# args_d.update({"do_train": False, "do_predict": True})
# main(argparse.Namespace(**args_d))
assert "n_train" in args_d
args = argparse.Namespace(**args_d)
main(args)
def test_bart_summarization_dataset(self):
tmp_dir = Path(tempfile.gettempdir())
@@ -138,42 +309,16 @@ class TestBartExamples(unittest.TestCase):
)
dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
for batch in dataloader:
self.assertEqual(batch["source_mask"].shape, batch["source_ids"].shape)
self.assertEqual(batch["attention_mask"].shape, batch["input_ids"].shape)
# show that articles were trimmed.
self.assertEqual(batch["source_ids"].shape[1], max_len_source)
self.assertGreater(20, batch["source_ids"].shape[1]) # trimmed significantly
self.assertEqual(batch["input_ids"].shape[1], max_len_source)
self.assertGreater(20, batch["input_ids"].shape[1]) # trimmed significantly
# show that targets were truncated
self.assertEqual(batch["target_ids"].shape[1], trunc_target) # Truncated
self.assertEqual(batch["decoder_input_ids"].shape[1], trunc_target) # Truncated
self.assertGreater(max_len_target, trunc_target) # Truncated
class TestT5Examples(unittest.TestCase):
def test_t5_cli(self):
output_file_name = "output_t5_sum.txt"
score_file_name = "score_t5_sum.txt"
articles = ["New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
tmp = Path(tempfile.gettempdir()) / "utest_generations_t5_sum.hypo"
with tmp.open("w", encoding="utf-8") as f:
f.write("\n".join(articles))
output_file_name = Path(tempfile.gettempdir()) / "utest_output_t5_sum.hypo"
score_file_name = Path(tempfile.gettempdir()) / "utest_score_t5_sum.hypo"
testargs = [
"evaluate_cnn.py",
str(tmp),
str(output_file_name),
"patrickvonplaten/t5-tiny-random",
"--reference_path",
str(tmp),
"--score_path",
str(score_file_name),
]
with patch.object(sys, "argv", testargs):
run_generate()
self.assertTrue(Path(output_file_name).exists())
self.assertTrue(Path(score_file_name).exists())
def list_to_text_file(lst, path):
dest = Path(path)
dest.open("w+").writelines(lst)

View File

@@ -1,20 +1,66 @@
import itertools
import json
import os
import pickle
from pathlib import Path
from typing import Dict, Iterable, List
import git
import numpy as np
import torch
from torch.utils.data import Dataset
from rouge_score import rouge_scorer, scoring
from torch import nn
from torch.utils.data import Dataset, Sampler
from tqdm import tqdm
from transformers import BartTokenizer
def encode_file(tokenizer, data_path, max_length, pad_to_max_length=True, return_tensors="pt"):
def encode_file(
tokenizer,
data_path,
max_length,
pad_to_max_length=True,
return_tensors="pt",
overwrite_cache=False,
prefix="",
tok_name="",
):
cache_path = Path(f"{data_path}_{tok_name}{max_length}.pt")
if not overwrite_cache and cache_path.exists():
try:
examples = torch.load(cache_path)
assert isinstance(examples, list)
return examples
except Exception:
print(f"failed to load from {cache_path}, retokenizing {data_path}")
data_path = Path(data_path)
lns = lmap(str.strip, data_path.open().readlines())
lns = [prefix + text for text in lns]
assert lns, f"found empty file at {data_path}"
examples = []
with open(data_path, "r") as f:
for text in f.readlines():
tokenized = tokenizer.batch_encode_plus(
[text], max_length=max_length, pad_to_max_length=pad_to_max_length, return_tensors=return_tensors,
)
examples.append(tokenized)
for text in tqdm(lns, desc=f"Tokenizing {data_path.name}"):
tokenized = tokenizer.batch_encode_plus(
[text], # DON'T ADD SPACES
max_length=max_length,
pad_to_max_length=pad_to_max_length,
add_prefix_space=True,
return_tensors=return_tensors,
)
examples.append(tokenized)
torch.save(lmap(dict, examples), cache_path.open("wb"))
return examples
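A minimal usage sketch of the cached tokenization above (the file name and tiny checkpoint are illustrative assumptions, not part of this diff):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("sshleifer/bart-tiny-random")
# first call tokenizes train.source and caches the result as train.source_12.pt next to it
examples = encode_file(tok, "train.source", max_length=12)
# second call loads the cached tensors instead of retokenizing, unless overwrite_cache=True
examples = encode_file(tok, "train.source", max_length=12)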
def lmap(f, x):
return list(map(f, x))
T5_PREFIX = "summarize: " # HACK, fixme
def trim_batch(
input_ids, pad_token_id, attention_mask=None,
):
@@ -30,15 +76,38 @@ class SummarizationDataset(Dataset):
def __init__(
self,
tokenizer,
data_dir="./cnn-dailymail/cnn_dm/",
data_dir,
type_path="train",
max_source_length=1024,
max_target_length=56,
n_obs=None,
overwrite_cache=False,
prefix="",
):
super().__init__()
self.tokenizer = tokenizer
tok_name = "T5" if not isinstance(tokenizer, BartTokenizer) else ""
self.source = encode_file(
tokenizer,
os.path.join(data_dir, type_path + ".source"),
max_source_length,
overwrite_cache=overwrite_cache,
prefix=prefix,
tok_name=tok_name,
)
if type_path == "train":
tgt_path = os.path.join(data_dir, type_path + ".target")
else:
tgt_path = os.path.join(data_dir, type_path + ".target")
self.target = encode_file(
tokenizer, tgt_path, max_target_length, overwrite_cache=overwrite_cache, tok_name=tok_name
)
self.source = encode_file(tokenizer, os.path.join(data_dir, type_path + ".source"), max_source_length)
self.target = encode_file(tokenizer, os.path.join(data_dir, type_path + ".target"), max_target_length)
if n_obs is not None:
self.source = self.source[:n_obs]
self.target = self.target[:n_obs]
self.pad_token_id = tokenizer.pad_token_id
def __len__(self):
return len(self.source)
@@ -47,19 +116,141 @@ class SummarizationDataset(Dataset):
source_ids = self.source[index]["input_ids"].squeeze()
target_ids = self.target[index]["input_ids"].squeeze()
src_mask = self.source[index]["attention_mask"].squeeze()
return {"source_ids": source_ids, "source_mask": src_mask, "target_ids": target_ids}
return {"input_ids": source_ids, "attention_mask": src_mask, "decoder_input_ids": target_ids}
@staticmethod
def trim_seq2seq_batch(batch, pad_token_id):
y = trim_batch(batch["target_ids"], pad_token_id)
source_ids, source_mask = trim_batch(batch["source_ids"], pad_token_id, attention_mask=batch["source_mask"])
y = trim_batch(batch["decoder_input_ids"], pad_token_id)
source_ids, source_mask = trim_batch(batch["input_ids"], pad_token_id, attention_mask=batch["attention_mask"])
return source_ids, source_mask, y
def collate_fn(self, batch):
input_ids = torch.stack([x["source_ids"] for x in batch])
masks = torch.stack([x["source_mask"] for x in batch])
target_ids = torch.stack([x["target_ids"] for x in batch])
pad_token_id = self.tokenizer.pad_token_id
def collate_fn(self, batch) -> dict:
input_ids = torch.stack([x["input_ids"] for x in batch])
masks = torch.stack([x["attention_mask"] for x in batch])
target_ids = torch.stack([x["decoder_input_ids"] for x in batch])
pad_token_id = self.pad_token_id
y = trim_batch(target_ids, pad_token_id)
source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks)
return {"source_ids": source_ids, "source_mask": source_mask, "target_ids": y}
batch = {"input_ids": source_ids, "attention_mask": source_mask, "decoder_input_ids": y}
return batch
@property
def src_lens(self): # Can delete?
return lmap(len, self.source)
@property
def tgt_lens(self):
return lmap(len, self.target)
def make_sortish_sampler(self, batch_size):
return SortishSampler(self.source, batch_size)
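For orientation, a sketch of how the dataset and its collate_fn might be consumed (the data_dir and checkpoint name are illustrative assumptions):

from torch.utils.data import DataLoader
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("sshleifer/bart-tiny-random")
ds = SummarizationDataset(tok, data_dir="cnn_tiny", type_path="train", max_source_length=12, max_target_length=12)
loader = DataLoader(ds, batch_size=2, collate_fn=ds.collate_fn)
batch = next(iter(loader))
# batch is a dict of input_ids, attention_mask and decoder_input_ids,
# trimmed to the longest non-pad sequence in the batch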
class SortishSampler(Sampler):
"Go through the text data by order of src length with a bit of randomness. From fastai repo."
def __init__(self, data, batch_size):
self.data, self.bs = data, batch_size
def key(self, i):
return len(self.data[i])
def __len__(self) -> int:
return len(self.data)
def __iter__(self):
idxs = np.random.permutation(len(self.data))
sz = self.bs * 50
ck_idx = [idxs[i : i + sz] for i in range(0, len(idxs), sz)]
sort_idx = np.concatenate([sorted(s, key=self.key, reverse=True) for s in ck_idx])
sz = self.bs
ck_idx = [sort_idx[i : i + sz] for i in range(0, len(sort_idx), sz)]
max_ck = np.argmax([self.key(ck[0]) for ck in ck_idx]) # find the chunk with the largest key,
ck_idx[0], ck_idx[max_ck] = ck_idx[max_ck], ck_idx[0] # then make sure it goes first.
sort_idx = np.concatenate(np.random.permutation(ck_idx[1:])) if len(ck_idx) > 1 else np.array([], dtype=np.int)
sort_idx = np.concatenate((ck_idx[0], sort_idx))
return iter(sort_idx)
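Continuing the sketch above: the sortish sampler is meant to replace shuffle=True so that each batch groups sources of similar length and wastes less padding.

sampler = ds.make_sortish_sampler(batch_size=2)
loader = DataLoader(ds, batch_size=2, sampler=sampler, collate_fn=ds.collate_fn)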
def use_task_specific_params(model, task):
# update config with summarization specific params
task_specific_params = model.config.task_specific_params
if task_specific_params is not None:
model.config.update(task_specific_params.get(task, {}))
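A hedged example of the effect: if a checkpoint's config carries task_specific_params for the requested task, those values are copied onto the top-level config (the checkpoint name is illustrative):

from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained("sshleifer/bart-tiny-random")
use_task_specific_params(model, "summarization")
# if the config has a "summarization" entry, fields such as max_length / min_length / num_beams
# now override the generic defaults in model.config; otherwise this is a no-op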
def pickle_load(path):
"""pickle.load(path)"""
with open(path, "rb") as f:
return pickle.load(f)
def pickle_save(obj, path):
"""pickle.dump(obj, path)"""
with open(path, "wb") as f:
return pickle.dump(obj, f)
def flatten_list(summary_ids: List[List]):
return list(itertools.chain.from_iterable(summary_ids))
def save_git_info(folder_path: str):
"""
Log commit info.
"""
repo_infos = get_git_info()
with open(os.path.join(folder_path, "git_log.json"), "w") as f:
json.dump(repo_infos, f, indent=4)
def get_git_info():
repo = git.Repo(search_parent_directories=True)
repo_infos = {
"repo_id": str(repo),
"repo_sha": str(repo.head.object.hexsha),
"repo_branch": str(repo.active_branch),
}
return repo_infos
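A brief usage sketch (the output directory is illustrative, and the call must run inside a git checkout): the commit metadata lands in git_log.json so a run can be traced back to the code that produced it.

save_git_info("output_dir")  # writes output_dir/git_log.json with repo_id, repo_sha, repo_branch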
ROUGE_KEYS = ["rouge1", "rouge2", "rougeL"]
def calculate_rouge(output_lns: List[str], reference_lns: List[str]) -> Dict:
scorer = rouge_scorer.RougeScorer(ROUGE_KEYS, use_stemmer=True)
aggregator = scoring.BootstrapAggregator()
for reference_ln, output_ln in zip(reference_lns, output_lns):
scores = scorer.score(reference_ln, output_ln)
aggregator.add_scores(scores)
result = aggregator.aggregate()
return {k: v.mid.fmeasure for k, v in result.items()}
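A small usage sketch (the strings are made up): the returned dict holds the mid f-measure for each ROUGE key, bootstrap-aggregated over the corpus.

preds = ["sam ate lunch today"]
refs = ["Sam ate lunch today."]
print(calculate_rouge(preds, refs))  # {'rouge1': ..., 'rouge2': ..., 'rougeL': ...}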
def freeze_params(model: nn.Module):
for par in model.parameters():
par.requires_grad = False
def grad_status(model: nn.Module) -> Iterable:
return (par.requires_grad for par in model.parameters())
def any_requires_grad(model: nn.Module) -> bool:
return any(grad_status(model))
def assert_all_frozen(model):
model_grads: List[bool] = list(grad_status(model))
n_require_grad = sum(lmap(int, model_grads))
npars = len(model_grads)
assert not any(model_grads), f"{n_require_grad/npars:.1%} of {npars} weights require grad"
def assert_not_all_frozen(model):
model_grads: List[bool] = list(grad_status(model))
npars = len(model_grads)
assert any(model_grads), f"none of {npars} weights require grad"
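These helpers back the freeze_encoder / freeze_embeds style flags used elsewhere in this commit; a hedged sketch with the tiny BART checkpoint (the model.model.encoder path assumes the BART module layout):

from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained("sshleifer/bart-tiny-random")
freeze_params(model.model.encoder)       # encoder receives no gradient updates
assert_all_frozen(model.model.encoder)   # raises if any encoder weight still requires grad
assert_not_all_frozen(model)             # the decoder (at least) must remain trainable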

View File

@@ -59,7 +59,7 @@ BART_GENERATION_EXAMPLE = r"""
Examples::
from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig
# see ``examples/summarization/bart/evaluate_cnn.py`` for a longer example
# see ``examples/summarization/bart/run_eval.py`` for a longer example
model = BartForConditionalGeneration.from_pretrained('bart-large-cnn')
tokenizer = BartTokenizer.from_pretrained('bart-large-cnn')
ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."