Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)

[examples] SummarizationModule improvements (#4951)

parent: cd40f6564e
commit: 043f9f51f9
@@ -2,6 +2,8 @@ import argparse
import logging
import os
import random
from pathlib import Path
from typing import Any, Dict

import numpy as np
import pytorch_lightning as pl
@@ -13,10 +15,13 @@ from transformers import (
    AutoModel,
    AutoModelForPreTraining,
    AutoModelForQuestionAnswering,
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
    AutoModelWithLMHead,
    AutoTokenizer,
    PretrainedConfig,
    PreTrainedTokenizer,
    get_linear_schedule_with_warmup,
)

@@ -31,6 +36,8 @@ MODEL_MODES = {
    "pretraining": AutoModelForPreTraining,
    "token-classification": AutoModelForTokenClassification,
    "language-modeling": AutoModelWithLMHead,
    "summarization": AutoModelForSeq2SeqLM,
    "translation": AutoModelForSeq2SeqLM,
}


@@ -38,33 +45,59 @@ def set_seed(args: argparse.Namespace):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
    if args.gpus > 0:
        torch.cuda.manual_seed_all(args.seed)


class BaseTransformer(pl.LightningModule):
    def __init__(self, hparams: argparse.Namespace, num_labels=None, mode="base", **config_kwargs):
    def __init__(
        self,
        hparams: argparse.Namespace,
        num_labels=None,
        mode="base",
        config=None,
        tokenizer=None,
        model=None,
        **config_kwargs
    ):
        "Initialize a model."

        super().__init__()
        self.hparams = hparams
        self.step_count = 0
        self.tfmr_ckpts = {}
        self.output_dir = Path(self.hparams.output_dir)
        cache_dir = self.hparams.cache_dir if self.hparams.cache_dir else None
        self.config = AutoConfig.from_pretrained(
            self.hparams.config_name if self.hparams.config_name else self.hparams.model_name_or_path,
            **({"num_labels": num_labels} if num_labels is not None else {}),
            cache_dir=cache_dir,
            **config_kwargs,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.hparams.tokenizer_name if self.hparams.tokenizer_name else self.hparams.model_name_or_path,
            cache_dir=cache_dir,
        )
        self.model = MODEL_MODES[mode].from_pretrained(
            self.hparams.model_name_or_path,
            from_tf=bool(".ckpt" in self.hparams.model_name_or_path),
            config=self.config,
            cache_dir=cache_dir,
        )
        if config is None:
            self.config = AutoConfig.from_pretrained(
                self.hparams.config_name if self.hparams.config_name else self.hparams.model_name_or_path,
                **({"num_labels": num_labels} if num_labels is not None else {}),
                cache_dir=cache_dir,
                **config_kwargs,
            )
        else:
            self.config: PretrainedConfig = config
        if tokenizer is None:
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.hparams.tokenizer_name if self.hparams.tokenizer_name else self.hparams.model_name_or_path,
                cache_dir=cache_dir,
            )
        else:
            self.tokenizer: PreTrainedTokenizer = tokenizer
        if model is None:
            self.model_type = MODEL_MODES[mode]
            self.model = self.model_type.from_pretrained(
                self.hparams.model_name_or_path,
                from_tf=bool(".ckpt" in self.hparams.model_name_or_path),
                config=self.config,
                cache_dir=cache_dir,
            )
        else:
            self.model_type = None
            self.model = model

    def load_hf_checkpoint(self, *args, **kwargs):
        self.model = self.model_type.from_pretrained(*args, **kwargs)

    def is_logger(self):
        return self.trainer.proc_rank <= 0
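As an aside, a hypothetical sketch of what the new constructor arguments enable: callers can now build the config, tokenizer, and model themselves and pass them in, instead of letting `BaseTransformer` call `from_pretrained` internally. The hparams values and model name below are placeholders, not part of this diff.

```python
import argparse

from lightning_base import BaseTransformer
from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer

# Placeholder hyperparameters; only output_dir and cache_dir are actually read
# when config/tokenizer/model are supplied explicitly.
hparams = argparse.Namespace(output_dir="outputs", cache_dir="", model_name_or_path="facebook/bart-large")

config = AutoConfig.from_pretrained(hparams.model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(hparams.model_name_or_path)
model = AutoModelForSeq2SeqLM.from_pretrained(hparams.model_name_or_path, config=config)

module = BaseTransformer(hparams, mode="summarization", config=config, tokenizer=tokenizer, model=model)
```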
@@ -138,6 +171,15 @@ class BaseTransformer(pl.LightningModule):
            ),
        )

    @pl.utilities.rank_zero_only
    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        save_path = self.output_dir.joinpath("best_tfmr")
        save_path.mkdir(exist_ok=True)
        self.model.config.save_step = self.step_count
        self.model.save_pretrained(save_path)
        self.tokenizer.save_pretrained(save_path)
        self.tfmr_ckpts[self.step_count] = save_path
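`on_save_checkpoint` writes a plain transformers checkpoint next to the lightning .ckpt files; a short sketch of reloading it (the directory name depends on whatever `--output_dir` was used):

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# "outputs" stands in for the training --output_dir.
model = AutoModelForSeq2SeqLM.from_pretrained("outputs/best_tfmr")
tokenizer = AutoTokenizer.from_pretrained("outputs/best_tfmr")
print(model.config.save_step)  # step at which this checkpoint was written
```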

    @staticmethod
    def add_model_specific_args(parser, root_dir):
        parser.add_argument(
@@ -152,7 +194,7 @@ class BaseTransformer(pl.LightningModule):
        )
        parser.add_argument(
            "--tokenizer_name",
            default="",
            default=None,
            type=str,
            help="Pretrained tokenizer name or path if not the same as model_name",
        )
@@ -165,7 +207,7 @@ class BaseTransformer(pl.LightningModule):
        parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
        parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
        parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
        parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
        parser.add_argument("--warmup_steps", default=500, type=int, help="Linear warmup over warmup_steps.")
        parser.add_argument(
            "--num_train_epochs", default=3, type=int, help="Total number of training epochs to perform."
        )
@@ -199,7 +241,8 @@ class LoggingCallback(pl.Callback):
                    writer.write("{} = {}\n".format(key, str(metrics[key])))


def add_generic_args(parser, root_dir):
def add_generic_args(parser, root_dir) -> None:
    # TODO(SS): allow all pl args? parser = pl.Trainer.add_argparse_args(parser)
    parser.add_argument(
        "--output_dir",
        default=None,
@@ -221,8 +264,8 @@ def add_generic_args(parser, root_dir):
        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html",
    )

    parser.add_argument("--n_gpu", type=int, default=1)
    parser.add_argument("--fast_dev_run", action="store_true")
    parser.add_argument("--gpus", type=int, default=1)
    parser.add_argument("--n_tpu_cores", type=int, default=0)
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
@@ -235,28 +278,32 @@ def add_generic_args(parser, root_dir):
    )

    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
    parser.add_argument("--resume_from_checkpoint", type=str, default=None)
    parser.add_argument("--val_check_interval", default=1.0, type=float)

def generic_train(model: BaseTransformer, args: argparse.Namespace):
def generic_train(
    model: BaseTransformer,
    args: argparse.Namespace,
    early_stopping_callback=False,
    logger=True,  # can pass WandbLogger() here
    extra_callbacks=[],
    checkpoint_callback=None,
    logging_callback=None,
    **extra_train_kwargs
):
    # init model
    set_seed(args)
    odir = Path(model.hparams.output_dir)
    odir.mkdir(exist_ok=True)
    if checkpoint_callback is None:
        checkpoint_callback = pl.callbacks.ModelCheckpoint(
            filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=1
        )
    if logging_callback is None:
        logging_callback = LoggingCallback()

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train:
        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))

    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=5
    )

    train_params = dict(
        accumulate_grad_batches=args.gradient_accumulation_steps,
        gpus=args.n_gpu,
        max_epochs=args.num_train_epochs,
        early_stop_callback=False,
        gradient_clip_val=args.max_grad_norm,
        checkpoint_callback=checkpoint_callback,
        callbacks=[LoggingCallback()],
    )
    train_params = {}

    if args.fp16:
        train_params["use_amp"] = args.fp16
@@ -269,12 +316,27 @@ def generic_train(model: BaseTransformer, args: argparse.Namespace):
        train_params["num_tpu_cores"] = args.n_tpu_cores
        train_params["gpus"] = 0

    if args.n_gpu > 1:
    if args.gpus > 1:
        train_params["distributed_backend"] = "ddp"

    trainer = pl.Trainer(**train_params)
    trainer = pl.Trainer(
        logger=logger,
        accumulate_grad_batches=args.gradient_accumulation_steps,
        gpus=args.gpus,
        max_epochs=args.num_train_epochs,
        early_stop_callback=early_stopping_callback,
        gradient_clip_val=args.max_grad_norm,
        checkpoint_callback=checkpoint_callback,
        callbacks=[logging_callback] + extra_callbacks,
        fast_dev_run=args.fast_dev_run,
        val_check_interval=args.val_check_interval,
        weights_summary=None,
        resume_from_checkpoint=args.resume_from_checkpoint,
        **train_params,
    )

    if args.do_train:
        trainer.fit(model)

    trainer.logger.log_hyperparams(args)
    trainer.logger.save()
    return trainer
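A sketch of calling the extended `generic_train` signature; the WandbLogger name and the `model`/`args` objects are assumed to be provided by the caller, this is not part of the diff.

```python
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

from lightning_base import generic_train

# `model` is any BaseTransformer subclass, `args` the namespace produced by
# add_generic_args / add_model_specific_args.
trainer = generic_train(
    model,
    args,
    logger=WandbLogger(name="my_run"),  # defaults to True (local logger) if omitted
    checkpoint_callback=pl.callbacks.ModelCheckpoint(
        filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=1
    ),
    early_stopping_callback=False,
)
```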

@@ -5,5 +5,6 @@ psutil
sacrebleu
rouge-score
tensorflow_datasets
pytorch-lightning==0.7.3  # April 10, 2020 release
pytorch-lightning==0.7.6
matplotlib
git-python==1.0.3
@@ -1,47 +1,70 @@
### Get CNN Data
To be able to reproduce the authors' results on the CNN/Daily Mail dataset, you first need to download both the CNN and Daily Mail datasets [from Kyunghyun Cho's website](https://cs.nyu.edu/~kcho/DMQA/) (the links next to "Stories") into the same folder. Then uncompress the archives by running:
### Data

CNN/DailyMail data:
```bash
cd examples/summarization
wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/cnn_dm.tgz
tar -xzvf cnn_dm.tgz
export CNN_DIR=${PWD}/cnn_dm
```

This should make a directory called `cnn_dm/` with files like `test.source`.
To use your own data, copy that file format. Each article to be summarized is on its own line.
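A minimal sketch of producing that layout from your own (article, summary) pairs; the directory name and example strings are made up, but the `*.source`/`*.target` file names match what `--data_dir` is expected to contain.

```python
from pathlib import Path

# Made-up example pairs: one article and its reference summary per line.
pairs = [
    ("First article text ...", "First reference summary."),
    ("Second article text ...", "Second reference summary."),
]

data_dir = Path("my_dataset")
data_dir.mkdir(exist_ok=True)
with open(data_dir / "train.source", "w") as src, open(data_dir / "train.target", "w") as tgt:
    for article, summary in pairs:
        src.write(article.replace("\n", " ") + "\n")
        tgt.write(summary.replace("\n", " ") + "\n")
# Repeat for val.* and test.*, then pass --data_dir my_dataset to finetune.py.
```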

XSUM data:
```bash
cd examples/summarization
wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/xsum.tar.gz
tar -xzvf xsum.tar.gz
export XSUM_DIR=${PWD}/xsum
```


### Evaluation

To create summaries for each article in the dataset, run:
```bash
python evaluate_cnn.py <path_to_test.source> test_generations.txt <model-name> --score_path rouge_scores.txt
python run_eval.py <path_to_test.source> test_generations.txt <model-name> --score_path rouge_scores.txt
```
The default batch size, 8, fits in 16GB GPU memory, but may need to be adjusted to fit your system.
The default batch size, 4, fits in 16GB GPU memory, but may need to be adjusted to fit your system.
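For context, a rough sketch of what the evaluation script does internally (batched generation with a seq2seq model); the exact flags and defaults of `run_eval.py` may differ, and the model name here is just an example.

```python
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn").to(device)
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")

articles = [x.rstrip() for x in open("cnn_dm/test.source").readlines()][:4]
batch = tokenizer.batch_encode_plus(
    articles, max_length=1024, return_tensors="pt", pad_to_max_length=True
).to(device)
summaries = model.generate(**batch)
print(tokenizer.batch_decode(summaries, skip_special_tokens=True))
```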

### Training
Run/modify `finetune_bart.sh` or `finetune_t5.sh`
Run/modify `finetune.sh`

The following command should work on a 16GB GPU:
```bash
export me=`git config user.name`
./finetune.sh \
    --data_dir $XSUM_DIR \
    --train_batch_size=1 \
    --eval_batch_size=1 \
    --output_dir="$me"_xsum_results \
    --num_train_epochs 1
```

### Stanford CoreNLP Setup
```bash
ptb_tokenize () {
    cat $1 | java edu.stanford.nlp.process.PTBTokenizer -ioFileList -preserveLines > $2
}

sudo apt install openjdk-8-jre-headless
sudo apt-get install ant
wget http://nlp.stanford.edu/software/stanford-corenlp-full-2018-10-05.zip
unzip stanford-corenlp-full-2018-10-05.zip
cd stanford-corenlp-full-2018-10-05
export CLASSPATH=stanford-corenlp-3.9.2.jar:stanford-corenlp-3.9.2-models.jar
```
Then run `ptb_tokenize` on `test.target` and your generated hypotheses.

### Rouge Setup
Install `files2rouge` following the instructions [here](https://github.com/pltrdy/files2rouge).
I also needed to run `sudo apt-get install libxml-parser-perl`.

Tips:
- 1 epoch at batch size 1 for `bart-large` takes 24 hours and requires 13GB GPU RAM with fp16 on an NVIDIA V100.
- Try `bart-base`, `--freeze_encoder`, or `--freeze_embeds` for faster training / a larger batch size (3hr/epoch with bs=8, see below).
- `fp16_opt_level=O1` (the default) works best.
- If you are finetuning on your own dataset, start from `bart-large-cnn` if you want long summaries and `bart-large-xsum` if you want short summaries.
(It rarely makes sense to start from `bart-large` unless you are researching finetuning methods.)
- In addition to the pytorch-lightning .ckpt checkpoint, a transformers checkpoint will be saved.
Load it with `BartForConditionalGeneration.from_pretrained(f'{output_dir}/best_tfmr')`, as in the sketch after this list.
- At the moment, `--do_predict` does not work in a multi-gpu setting. You need to use `evaluate_checkpoint` or the `run_eval.py` code.
- If you want to run experiments on improving the summarization finetuning process, try the XSUM Shared Task (below). It's faster to train than CNNDM because the summaries are shorter.
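A minimal sketch of reloading that transformers checkpoint (the directory name is whatever was passed as `--output_dir`):

```python
from transformers import AutoTokenizer, BartForConditionalGeneration

output_dir = "bart_cnn_finetune"  # whatever --output_dir was
model = BartForConditionalGeneration.from_pretrained(f"{output_dir}/best_tfmr")
tokenizer = AutoTokenizer.from_pretrained(f"{output_dir}/best_tfmr")

batch = tokenizer.batch_encode_plus(["An article to summarize ..."], return_tensors="pt")
summary_ids = model.generate(batch["input_ids"])
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))
```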

```python
from files2rouge import files2rouge
from files2rouge import settings

files2rouge.run(<path_to_tokenized_hypo>,
                <path_to_tokenized_target>,
                saveto='rouge_output.txt')
```

### XSUM Shared Task
Compare XSUM results with others by using `--logger wandb_shared`. This requires `wandb` registration.
Here is an example command:
```bash
export me=`git config user.name`
./finetune.sh \
    --data_dir $XSUM_DIR \
    --output_dir "$me"_xsum_frozen_embs \
    --logger wandb_shared \
    --train_batch_size 16 --eval_batch_size 16 --freeze_embeds --freeze_encoder \
    --num_train_epochs 6
```

Results can be viewed [here](https://app.wandb.ai/sshleifer/hf_summarization/table?workspace=user-).
85 examples/summarization/callbacks.py Normal file
@@ -0,0 +1,85 @@
import logging
import os
from pathlib import Path

import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities import rank_zero_only


def count_trainable_parameters(model):
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    return params


logger = logging.getLogger(__name__)


class Seq2SeqLoggingCallback(pl.Callback):
    def _write_logs(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule, type_path: str, save_generations=True
    ) -> None:
        logger.info(f"***** {type_path} results at step {trainer.global_step:05d} *****")
        if not pl_module.is_logger():
            return
        metrics = trainer.callback_metrics
        trainer.logger.log_metrics({k: v for k, v in metrics.items() if k not in ["log", "progress_bar", "preds"]})
        # Log results
        od = Path(pl_module.hparams.output_dir)
        if type_path == "test":
            results_file = od / "test_results.txt"
            generations_file = od / "test_generations.txt"
        else:
            results_file = od / f"{type_path}_results_{trainer.global_step:05d}.txt"
            generations_file = od / f"{type_path}_generations_{trainer.global_step:05d}.txt"

        with open(results_file, "a+") as writer:
            for key in sorted(metrics):
                if key in ["log", "progress_bar", "preds"]:
                    continue
                val = metrics[key]
                if isinstance(val, torch.Tensor):
                    val = val.item()
                msg = f"{key}: {val:.6f}\n"
                writer.write(msg)

        if not save_generations:
            return

        if "preds" in metrics:
            content = "\n".join(metrics["preds"])
            generations_file.open("w+").write(content)

    @rank_zero_only
    def on_train_start(self, trainer, pl_module):
        try:
            npars = pl_module.model.model.num_parameters()
        except AttributeError:
            npars = pl_module.model.num_parameters()

        n_trainable_pars = count_trainable_parameters(pl_module)
        # mp stands for million parameters
        trainer.logger.log_metrics({"n_params": npars, "mp": npars / 1e6, "grad_mp": n_trainable_pars / 1e6})

    @rank_zero_only
    def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        return self._write_logs(trainer, pl_module, "val")

    @rank_zero_only
    def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        return self._write_logs(trainer, pl_module, "test")


def get_rouge2_checkpoint_callback(output_dir):
    """Saves the best model by validation ROUGE2 score."""
    checkpoint_callback = ModelCheckpoint(
        filepath=os.path.join(output_dir, "{val_avg_rouge2:.4f}-{step_count}"),
        monitor="val_rouge",
        mode="max",
        save_top_k=1,
        period=0,  # maybe save a checkpoint every time val is run, not just end of epoch.
    )
    return checkpoint_callback
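Both helpers are meant to be passed to `generic_train`, roughly the way `finetune.py` wires them up:

```python
from callbacks import Seq2SeqLoggingCallback, get_rouge2_checkpoint_callback
from lightning_base import generic_train

# `model` is a SummarizationModule and `args` the parsed argparse.Namespace.
trainer = generic_train(
    model,
    args,
    logging_callback=Seq2SeqLoggingCallback(),
    checkpoint_callback=get_rouge2_checkpoint_callback(args.output_dir),
)
```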
448 examples/summarization/distillation.py Normal file
@ -0,0 +1,448 @@
|
||||
import argparse
|
||||
import gc
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from lightning_base import generic_train
|
||||
from transformers import AdamW, BartConfig, BartForConditionalGeneration, T5Config, T5ForConditionalGeneration
|
||||
|
||||
|
||||
try:
|
||||
from .finetune import SummarizationModule
|
||||
from .initialization_utils import init_student, copy_layers
|
||||
from .utils import (
|
||||
use_task_specific_params,
|
||||
SummarizationDataset,
|
||||
pickle_load,
|
||||
freeze_params,
|
||||
assert_all_frozen,
|
||||
any_requires_grad,
|
||||
)
|
||||
from .finetune import main as ft_main
|
||||
except ImportError:
|
||||
from finetune import SummarizationModule
|
||||
from finetune import main as ft_main
|
||||
from initialization_utils import init_student, copy_layers
|
||||
from utils import (
|
||||
use_task_specific_params,
|
||||
SummarizationDataset,
|
||||
pickle_load,
|
||||
freeze_params,
|
||||
assert_all_frozen,
|
||||
any_requires_grad,
|
||||
)
|
||||
|
||||
|
||||
class SummarizationDistiller(SummarizationModule):
|
||||
loss_names = ["loss", "ce_loss", "mlm_loss", "enc_mse_loss", "hid_loss_enc", "hid_loss_dec"]
|
||||
|
||||
def __init__(self, hparams):
|
||||
assert Path(hparams.data_dir).exists()
|
||||
|
||||
d_layers_to_copy, student, student_cfg, teacher = self.pre_init(hparams)
|
||||
|
||||
super().__init__(hparams, model=student, config=student_cfg)
|
||||
self.teacher = teacher
|
||||
use_task_specific_params(self.teacher, "summarization")
|
||||
freeze_params(self.teacher)
|
||||
self.sanity_check_gradients()
|
||||
self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean")
|
||||
self.temperature = 2.0
|
||||
self.alpha_mlm = hparams.alpha_mlm
|
||||
self.alpha_ce = hparams.alpha_ce
|
||||
self.alpha_hid = hparams.alpha_hid
|
||||
# self.alpha_cos = hparams.alpha_cos
|
||||
self.alpha_encoder_loss = self.hparams.alpha_encoder_loss
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def sanity_check_gradients(self):
|
||||
assert_all_frozen(self.teacher)
|
||||
assert_all_frozen(self.model.model.decoder.embed_tokens)
|
||||
assert_all_frozen(self.model.model.encoder.embed_tokens)
|
||||
if self.different_encoder:
|
||||
assert any_requires_grad(self.model.model.encoder)
|
||||
else:
|
||||
freeze_params(self.model.model.encoder)
|
||||
del self.teacher.model.encoder
|
||||
|
||||
def pre_init(self, hparams):
|
||||
# Dump empty student model at a path, then call from_pretrained on it
|
||||
teacher = BartForConditionalGeneration.from_pretrained(hparams.teacher).eval()
|
||||
student_updates = {
|
||||
"decoder_layers": hparams.student_decoder_layers,
|
||||
"encoder_layers": hparams.student_encoder_layers,
|
||||
}
|
||||
d_layers_to_copy = get_layers_to_copy(student_updates["decoder_layers"], teacher.config.decoder_layers)
|
||||
e_layers_to_copy: List = get_layers_to_copy(student_updates["encoder_layers"], teacher.config.encoder_layers)
|
||||
hparams.d_layer_to_copy = d_layers_to_copy
|
||||
hparams.e_layer_to_copy = e_layers_to_copy
|
||||
kw = teacher.config.to_diff_dict()
|
||||
kw.update(student_updates)
|
||||
# Copy weights
|
||||
student_cfg = BartConfig(**kw)
|
||||
student = BartForConditionalGeneration(student_cfg)
|
||||
student, _ = init_student(student, teacher)
|
||||
self.copy_to_student(d_layers_to_copy, e_layers_to_copy, hparams, student, teacher)
|
||||
Path(hparams.output_dir).mkdir(exist_ok=True)
|
||||
return d_layers_to_copy, student, student_cfg, teacher
|
||||
|
||||
def copy_to_student(self, d_layers_to_copy, e_layers_to_copy, hparams, student, teacher):
|
||||
if teacher.config.model_type == "t5":
|
||||
return self.copy_t5_to_student(d_layers_to_copy, e_layers_to_copy, hparams, student, teacher)
|
||||
self.different_encoder: bool = hparams.student_encoder_layers != teacher.config.encoder_layers
|
||||
self.different_decoder = hparams.student_decoder_layers != teacher.config.decoder_layers
|
||||
if self.different_decoder:
|
||||
copy_layers(teacher.model.decoder.layers, student.model.decoder.layers, d_layers_to_copy)
|
||||
if self.different_encoder:
|
||||
copy_layers(teacher.model.encoder.layers, student.model.encoder.layers, e_layers_to_copy)
|
||||
|
||||
def copy_t5_to_student(self, d_layers_to_copy, e_layers_to_copy, hparams, student, teacher):
|
||||
self.different_encoder: bool = hparams.student_encoder_layers != teacher.config.num_layers
|
||||
self.different_decoder = hparams.student_decoder_layers != teacher.config.num_layers
|
||||
if self.different_decoder:
|
||||
copy_layers(teacher.decoder.block, student.decoder.block, d_layers_to_copy)
|
||||
if self.different_encoder:
|
||||
copy_layers(teacher.encoder.block, student.encoder.block, e_layers_to_copy)
|
||||
|
||||
def get_dataset(self, type_path) -> SummarizationDataset:
|
||||
n_obs = self.n_obs[type_path]
|
||||
dataset = SummarizationDataset(self.tokenizer, type_path=type_path, n_obs=n_obs, **self.dataset_kwargs)
|
||||
return dataset
|
||||
|
||||
def calc_mse_loss(self, teacher_outputs: torch.Tensor, student_outputs: torch.Tensor, mask) -> torch.FloatTensor:
|
||||
if mask is not None:
|
||||
# mask has False at padding_idx
|
||||
sel_mask = mask[:, :, None].expand_as(student_outputs).bool()
|
||||
s_logits_slct = torch.masked_select(student_outputs, sel_mask)
|
||||
t_logits_slct = torch.masked_select(teacher_outputs, sel_mask)
|
||||
else:
|
||||
t_logits_slct = teacher_outputs
|
||||
s_logits_slct = student_outputs
|
||||
return F.mse_loss(s_logits_slct, t_logits_slct)
|
||||
|
||||
def calc_ce_loss(self, mask, s_logits, t_logits):
|
||||
if mask is not None:
|
||||
# mask has False at padding_idx
|
||||
sel_mask = mask[:, :, None].expand_as(s_logits)
|
||||
s_logits_slct = torch.masked_select(
|
||||
s_logits, sel_mask
|
||||
) # (bs * seq_length * voc_size) modulo the 1s in mask
|
||||
t_logits_slct = torch.masked_select(
|
||||
t_logits, sel_mask
|
||||
) # (bs * seq_length * voc_size) modulo the 1s in mask
|
||||
else:
|
||||
t_logits_slct = t_logits
|
||||
s_logits_slct = s_logits # (bs * seq_length * voc_size) modulo the 1s in mask
|
||||
s_logits_slct = s_logits_slct.view(-1, s_logits.size(-1)) # (bs * seq_length, voc_size) modulo the 1s in mask
|
||||
t_logits_slct = t_logits_slct.view(-1, s_logits.size(-1)) # (bs * seq_length, voc_size) modulo the 1s in mask
|
||||
assert t_logits_slct.size() == s_logits_slct.size()
|
||||
loss_ce = (
|
||||
self.ce_loss_fct(
|
||||
F.log_softmax(s_logits_slct / self.temperature, dim=-1),
|
||||
F.softmax(t_logits_slct / self.temperature, dim=-1),
|
||||
)
|
||||
* (self.temperature) ** 2
|
||||
)
|
||||
return loss_ce, s_logits_slct, t_logits_slct
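For intuition, a self-contained sketch of the temperature-scaled KL term computed above; the tensor shapes are toy values, not the module's real batch shapes.

```python
import torch
from torch import nn
from torch.nn import functional as F

temperature = 2.0
ce_loss_fct = nn.KLDivLoss(reduction="batchmean")

# Toy logits with shape (batch * seq_len, vocab_size).
s_logits = torch.randn(6, 50265)  # student
t_logits = torch.randn(6, 50265)  # teacher

loss_ce = ce_loss_fct(
    F.log_softmax(s_logits / temperature, dim=-1),
    F.softmax(t_logits / temperature, dim=-1),
) * temperature ** 2
```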
|
||||
|
||||
def configure_optimizers(self):
|
||||
"Prepare optimizer and schedule (linear warmup and decay)"
|
||||
|
||||
model = self.model
|
||||
no_decay = ["bias", "LayerNorm.weight"]
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
|
||||
"weight_decay": self.hparams.weight_decay,
|
||||
},
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
|
||||
"weight_decay": 0.0,
|
||||
},
|
||||
]
|
||||
optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)
|
||||
self.opt = optimizer
|
||||
return [optimizer]
|
||||
|
||||
@staticmethod
|
||||
def add_model_specific_args(parser, root_dir):
|
||||
SummarizationModule.add_model_specific_args(parser, root_dir)
|
||||
parser.add_argument("--teacher", default="facebook/bart-large-cnn", type=str)
|
||||
parser.add_argument("--alpha_ce", default=0.8, type=float)
|
||||
parser.add_argument("--alpha_mlm", default=0.2, type=float)
|
||||
# parser.add_argument("--alpha_cos", default=0.0, type=float)
|
||||
parser.add_argument("--alpha_encoder_loss", default=0.0, type=float)
|
||||
parser.add_argument("--alpha_hid", default=0.0, type=float, required=False)
|
||||
parser.add_argument(
|
||||
"--student_decoder_layers", default=12, type=int, required=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--student_encoder_layers", default=12, type=int, required=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no_teacher", action="store_true", default=False,
|
||||
)
|
||||
parser.add_argument( # TODO: remove
|
||||
"--enc_only", action="store_true", default=False,
|
||||
)
|
||||
return parser
|
||||
|
||||
def _step(self, batch):
|
||||
# assert is_frozen(self.teacher)
|
||||
pad_token_id = self.tokenizer.pad_token_id
|
||||
input_ids, src_mask, y = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]
|
||||
decoder_input_ids = y[:, :-1].contiguous()
|
||||
labels = y[:, 1:].clone()
|
||||
labels[y[:, 1:] == pad_token_id] = -100
|
||||
# noinspection PyCallingNonCallable
|
||||
sloss, slogits, dec_hidden, enc_outputs, enc_hidden_state = self(
|
||||
input_ids,
|
||||
attention_mask=src_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
labels=labels,
|
||||
output_hidden_states=True,
|
||||
output_attentions=False,
|
||||
)
|
||||
|
||||
def zero_tensor():
|
||||
return torch.tensor(0.0).type_as(sloss)
|
||||
|
||||
loss_encoder, hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor(), zero_tensor()
|
||||
if self.different_encoder:
|
||||
with torch.no_grad():
|
||||
teacher_enc_outputs, teacher_enc_hid, _ = self.teacher.model.encoder(
|
||||
input_ids, attention_mask=src_mask, output_hidden_states=True
|
||||
)
|
||||
if self.hparams.alpha_encoder_loss > 0:
|
||||
loss_encoder = self.calc_mse_loss(enc_outputs, teacher_enc_outputs, src_mask)
|
||||
|
||||
hid_loss_enc = self.calc_hidden_loss(
|
||||
src_mask, enc_hidden_state, teacher_enc_hid, self.hparams.e_layer_to_copy
|
||||
)
|
||||
|
||||
teacher_enc_outputs = (enc_outputs,)
|
||||
assert isinstance(teacher_enc_outputs, tuple), type(teacher_enc_outputs)
|
||||
|
||||
with torch.no_grad():
|
||||
tloss, tlogits, tdec_hidden, _ = self.teacher(
|
||||
input_ids,
|
||||
attention_mask=src_mask,
|
||||
encoder_outputs=teacher_enc_outputs,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
lm_labels=labels,
|
||||
output_hidden_states=True,
|
||||
)
|
||||
dec_mask = decoder_input_ids.ne(pad_token_id)
|
||||
loss_ce, s_logits_slct, t_logits_slct = self.calc_ce_loss(dec_mask, slogits, tlogits)
|
||||
if self.alpha_hid > 0:
|
||||
hid_loss_dec = self.calc_hidden_loss(dec_mask, dec_hidden, tdec_hidden, self.hparams.d_layer_to_copy)
|
||||
|
||||
blended_loss = (
|
||||
self.alpha_ce * loss_ce
|
||||
+ self.alpha_mlm * sloss
|
||||
+ self.hparams.alpha_encoder_loss * loss_encoder
|
||||
+ self.hparams.alpha_hid * (hid_loss_enc + hid_loss_dec)
|
||||
)
|
||||
return blended_loss, loss_ce, sloss, loss_encoder, hid_loss_enc, hid_loss_dec
|
||||
|
||||
def calc_hidden_loss(self, attention_mask, hidden_states, hidden_states_T, matches):
|
||||
assert not isinstance(
|
||||
hidden_states, torch.Tensor
|
||||
), f"expected list or tuple for hidden_states, got tensor of shape {hidden_states.shape}"
|
||||
assert not isinstance(
|
||||
hidden_states_T, torch.Tensor
|
||||
), f"expected list or tuple for hidden_states_T, got tensor of shape {hidden_states_T.shape}"
|
||||
mask = attention_mask.to(hidden_states[0])
|
||||
valid_count = mask.sum() * hidden_states[0].size(-1)
|
||||
hidden_losses = [
|
||||
(F.mse_loss(hidden_states[i], hidden_states_T[j], reduction="none") * mask.unsqueeze(-1)).sum()
|
||||
/ valid_count
|
||||
for i, j in enumerate(matches)
|
||||
]
|
||||
return sum(hidden_losses)
|
||||
|
||||
|
||||
class T5SummarizationDistiller(SummarizationDistiller):
|
||||
def pre_init(self, hparams):
|
||||
teacher = T5ForConditionalGeneration.from_pretrained(hparams.teacher)
|
||||
n_layer = hparams.student_decoder_layers
|
||||
assert n_layer == hparams.student_encoder_layers # TODO(SS): relax this
|
||||
d_layers_to_copy = get_layers_to_copy(n_layer, len(teacher.decoder.block))
|
||||
e_layers_to_copy: List = get_layers_to_copy(n_layer, len(teacher.encoder.block))
|
||||
student_updates = {"num_layers": n_layer}
|
||||
hparams.d_layer_to_copy = d_layers_to_copy
|
||||
hparams.e_layer_to_copy = e_layers_to_copy
|
||||
kw = teacher.config.to_diff_dict()
|
||||
|
||||
kw.update(student_updates)
|
||||
# Copy weights
|
||||
student_cfg = T5Config(**kw)
|
||||
student = T5ForConditionalGeneration(student_cfg)
|
||||
student, _ = init_student(student, teacher)
|
||||
self.copy_to_student(d_layers_to_copy, e_layers_to_copy, hparams, student, teacher)
|
||||
Path(hparams.output_dir).mkdir(exist_ok=True)
|
||||
task_specific_params = student.config.task_specific_params
|
||||
if task_specific_params is not None:
|
||||
student.config.update(task_specific_params.get("summarization", {}))
|
||||
return d_layers_to_copy, student, student_cfg, teacher
|
||||
|
||||
def freeze_embeds(self):
|
||||
freeze_params(self.model.shared)
|
||||
for d in [self.model.encoder, self.model.decoder]:
|
||||
freeze_params(d.embed_tokens)
|
||||
|
||||
def sanity_check_gradients(self):
|
||||
"""T5"""
|
||||
assert_all_frozen(self.teacher)
|
||||
assert_all_frozen(self.model.decoder.embed_tokens)
|
||||
assert_all_frozen(self.model.encoder.embed_tokens)
|
||||
if self.different_encoder:
|
||||
assert any_requires_grad(self.model.encoder)
|
||||
else:
|
||||
freeze_params(self.model.encoder)
|
||||
del self.teacher.model.encoder
|
||||
if self.different_decoder:
|
||||
assert any_requires_grad(self.model.decoder)
|
||||
else:
|
||||
freeze_params(self.model.decoder) # TODO(SS): very suspicious
|
||||
|
||||
def _step(self, batch):
|
||||
pad_token_id = self.tokenizer.pad_token_id
|
||||
source_ids, source_mask, y = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]
|
||||
decoder_input_ids = y[:, :-1].contiguous()
|
||||
labels = y[:, 1:].clone()
|
||||
labels[y[:, 1:] == pad_token_id] = -100
|
||||
# noinspection PyCallingNonCallable
|
||||
dec_mask = decoder_input_ids.ne(pad_token_id)
|
||||
|
||||
sloss, slogits, dec_hidden, enc_outputs, enc_hidden_state = self(
|
||||
source_ids,
|
||||
attention_mask=source_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
labels=labels,
|
||||
output_hidden_states=True,
|
||||
output_attentions=False,
|
||||
use_cache=False,
|
||||
)
|
||||
|
||||
def zero_tensor():
|
||||
return torch.tensor(0.0).type_as(sloss)
|
||||
|
||||
loss_encoder, hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor(), zero_tensor()
|
||||
if self.different_encoder:
|
||||
with torch.no_grad():
|
||||
teacher_enc_outputs, teacher_enc_hid = self.teacher.encoder(
|
||||
source_ids, attention_mask=source_mask, output_hidden_states=True, use_cache=False,
|
||||
)
|
||||
if self.hparams.alpha_encoder_loss > 0:
|
||||
loss_encoder = self.calc_mse_loss(enc_outputs, teacher_enc_outputs, source_mask)
|
||||
|
||||
hid_loss_enc = self.calc_hidden_loss(
|
||||
source_mask, enc_hidden_state, teacher_enc_hid, self.hparams.e_layer_to_copy
|
||||
)
|
||||
|
||||
teacher_enc_outputs = (enc_outputs,)
|
||||
assert isinstance(teacher_enc_outputs, tuple), type(teacher_enc_outputs)
|
||||
|
||||
with torch.no_grad():
|
||||
tloss, tlogits, tdec_hidden, _ = self.teacher(
|
||||
source_ids,
|
||||
attention_mask=source_mask,
|
||||
encoder_outputs=teacher_enc_outputs,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
lm_labels=labels,
|
||||
output_hidden_states=True,
|
||||
use_cache=False,
|
||||
)
|
||||
|
||||
loss_ce, s_logits_slct, t_logits_slct = self.calc_ce_loss(dec_mask, slogits, tlogits)
|
||||
if self.alpha_hid > 0:
|
||||
hid_loss_dec = self.calc_hidden_loss(dec_mask, dec_hidden, tdec_hidden, self.hparams.d_layer_to_copy)
|
||||
|
||||
blended_loss = (
|
||||
self.alpha_ce * loss_ce
|
||||
+ self.alpha_mlm * sloss
|
||||
+ self.hparams.alpha_encoder_loss * loss_encoder
|
||||
+ self.hparams.alpha_hid * (hid_loss_enc + hid_loss_dec)
|
||||
)
|
||||
return blended_loss, loss_ce, sloss, loss_encoder, hid_loss_enc, hid_loss_dec
|
||||
|
||||
|
||||
def create_module(args):
|
||||
t5 = "t5" in args.model_name_or_path
|
||||
if args.no_teacher:
|
||||
assert not args.enc_only
|
||||
module_cls = SummarizationModule
|
||||
elif t5:
|
||||
module_cls = T5SummarizationDistiller
|
||||
elif args.enc_only:
|
||||
raise ValueError("Deleted that")
|
||||
else:
|
||||
module_cls = SummarizationDistiller
|
||||
args.setup_cls: str = module_cls.__name__
|
||||
model = module_cls(args)
|
||||
return model
|
||||
|
||||
|
||||
def evaluate_checkpoint(ckpt_path: Path, dest_dir=None):
|
||||
exp_dir = ckpt_path.parent
|
||||
if dest_dir is None:
|
||||
dest_dir = exp_dir
|
||||
clash = list(dest_dir.glob("test_generations*"))
|
||||
if clash:
|
||||
print(f"SKIPPING to avoid overwriting {clash}")
|
||||
ckpt = torch.load(ckpt_path, map_location="cpu")
|
||||
if "hparams" in ckpt:
|
||||
args = argparse.Namespace(**ckpt["hparams"])
|
||||
else:
|
||||
args = argparse.Namespace(**pickle_load(exp_dir / "hparams.pkl"))
|
||||
args.resume_from_checkpoint = str(ckpt_path)
|
||||
args.do_train = False
|
||||
args.output_dir = str(dest_dir)
|
||||
args.n_gpu = 1
|
||||
args.eval_batch_size = 16
|
||||
Path(args.output_dir).mkdir(exist_ok=True)
|
||||
model = create_module(args)
|
||||
trainer: pl.Trainer = generic_train(model, args, early_stopping_callback=False)
|
||||
trainer.test(model)
|
||||
|
||||
|
||||
def get_layers_to_copy(n_to_get, tot):
|
||||
all_layers = list(range(tot))
|
||||
if tot == 12: # Alternating for special cases
|
||||
layers_to_copy = { # maps # layers in student -> which teacher layers to copy
|
||||
6: [0, 2, 4, 7, 9, 11],
|
||||
1: [11],
|
||||
3: [0, 6, 11],
|
||||
2: [0, 11],
|
||||
4: [0, 4, 8, 11],
|
||||
9: [0, 1, 2, 4, 5, 7, 9, 10, 11],
|
||||
12: all_layers,
|
||||
}
|
||||
return layers_to_copy[n_to_get]
|
||||
else:
|
||||
return all_layers[:n_to_get]
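A quick illustration of this mapping, using the values from the dictionary above:

```python
# 12-layer BART teacher: pick spread-out layers for the student.
assert get_layers_to_copy(6, 12) == [0, 2, 4, 7, 9, 11]
assert get_layers_to_copy(3, 12) == [0, 6, 11]
# Any other teacher depth: just take the first n layers.
assert get_layers_to_copy(4, 6) == [0, 1, 2, 3]
```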
|
||||
|
||||
|
||||
def distill_main(args):
|
||||
Path(args.output_dir).mkdir(exist_ok=True)
|
||||
if len(os.listdir(args.output_dir)) > 3 and args.do_train:
|
||||
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
|
||||
|
||||
model = create_module(args)
|
||||
return ft_main(args, model=model)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = SummarizationDistiller.add_model_specific_args(parser, os.getcwd())
|
||||
args = parser.parse_args()
|
||||
|
||||
distill_main(args)
|
@ -1,100 +0,0 @@
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from rouge_score import rouge_scorer, scoring
|
||||
from tqdm import tqdm
|
||||
|
||||
from transformers import AutoModelWithLMHead, AutoTokenizer
|
||||
|
||||
|
||||
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
def chunks(lst, n):
|
||||
"""Yield successive n-sized chunks from lst."""
|
||||
for i in range(0, len(lst), n):
|
||||
yield lst[i : i + n]
|
||||
|
||||
|
||||
def generate_summaries(
|
||||
examples: list, out_file: str, model_name: str, batch_size: int = 8, device: str = DEFAULT_DEVICE
|
||||
):
|
||||
fout = Path(out_file).open("w", encoding="utf-8")
|
||||
model = AutoModelWithLMHead.from_pretrained(model_name).to(device)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
|
||||
# update config with summarization specific params
|
||||
task_specific_params = model.config.task_specific_params
|
||||
if task_specific_params is not None:
|
||||
model.config.update(task_specific_params.get("summarization", {}))
|
||||
|
||||
for batch in tqdm(list(chunks(examples, batch_size))):
|
||||
if "t5" in model_name:
|
||||
batch = [model.config.prefix + text for text in batch]
|
||||
dct = tokenizer.batch_encode_plus(batch, max_length=1024, return_tensors="pt", pad_to_max_length=True).to(
|
||||
device
|
||||
)
|
||||
summaries = model.generate(**dct)
|
||||
|
||||
dec = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||
for hypothesis in dec:
|
||||
fout.write(hypothesis + "\n")
|
||||
fout.flush()
|
||||
|
||||
|
||||
def calculate_rouge(output_lns, reference_lns, score_path):
|
||||
score_file = Path(score_path).open("w")
|
||||
scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
|
||||
aggregator = scoring.BootstrapAggregator()
|
||||
|
||||
for reference_ln, output_ln in zip(reference_lns, output_lns):
|
||||
scores = scorer.score(reference_ln, output_ln)
|
||||
aggregator.add_scores(scores)
|
||||
|
||||
result = aggregator.aggregate()
|
||||
score_file.write(
|
||||
"ROUGE_1: \n{} \n\n ROUGE_2: \n{} \n\n ROUGE_L: \n{} \n\n".format(
|
||||
result["rouge1"], result["rouge2"], result["rougeL"]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def run_generate():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"input_path", type=str, help="like cnn_dm/test.source or cnn_dm/test_articles_input.txt",
|
||||
)
|
||||
parser.add_argument(
|
||||
"output_path", type=str, help="where to save summaries",
|
||||
)
|
||||
parser.add_argument(
|
||||
"model_name",
|
||||
type=str,
|
||||
default="facebook/bart-large-cnn",
|
||||
help="like bart-large-cnn,'t5-small', 't5-base', 't5-large', 't5-3b', 't5-11b",
|
||||
)
|
||||
parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test_reference_summaries.txt")
|
||||
parser.add_argument(
|
||||
"--score_path", type=str, required=False, help="where to save the rouge score",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device", type=str, required=False, default=DEFAULT_DEVICE, help="cuda, cuda:1, cpu etc.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bs", type=int, default=8, required=False, help="batch size: how many to summarize at a time",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
examples = [" " + x.rstrip() if "t5" in args.model_name else x.rstrip() for x in open(args.input_path).readlines()]
|
||||
|
||||
generate_summaries(examples, args.output_path, args.model_name, batch_size=args.bs, device=args.device)
|
||||
if args.score_path is not None:
|
||||
output_lns = [x.rstrip() for x in open(args.output_path).readlines()]
|
||||
reference_lns = [x.rstrip() for x in open(args.reference_path).readlines()]
|
||||
|
||||
calculate_rouge(output_lns, reference_lns, args.score_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_generate()
|
@ -3,91 +3,169 @@ import glob
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from lightning_base import BaseTransformer, add_generic_args, generic_train, get_linear_schedule_with_warmup
|
||||
from lightning_base import BaseTransformer, add_generic_args, generic_train
|
||||
from transformers import get_linear_schedule_with_warmup
|
||||
|
||||
|
||||
try:
|
||||
from .utils import SummarizationDataset
|
||||
from .utils import (
|
||||
use_task_specific_params,
|
||||
SummarizationDataset,
|
||||
lmap,
|
||||
flatten_list,
|
||||
pickle_save,
|
||||
save_git_info,
|
||||
freeze_params,
|
||||
calculate_rouge,
|
||||
get_git_info,
|
||||
ROUGE_KEYS,
|
||||
)
|
||||
from .callbacks import Seq2SeqLoggingCallback, get_rouge2_checkpoint_callback
|
||||
except ImportError:
|
||||
from utils import SummarizationDataset
|
||||
|
||||
from utils import (
|
||||
use_task_specific_params,
|
||||
SummarizationDataset,
|
||||
lmap,
|
||||
flatten_list,
|
||||
pickle_save,
|
||||
save_git_info,
|
||||
freeze_params,
|
||||
calculate_rouge,
|
||||
get_git_info,
|
||||
ROUGE_KEYS,
|
||||
)
|
||||
from callbacks import Seq2SeqLoggingCallback, get_rouge2_checkpoint_callback
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SummarizationTrainer(BaseTransformer):
|
||||
class SummarizationModule(BaseTransformer):
|
||||
mode = "summarization"
|
||||
loss_names = ["loss"]
|
||||
|
||||
mode = "language-modeling"
|
||||
def __init__(self, hparams, **kwargs):
|
||||
super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs)
|
||||
use_task_specific_params(self.model, "summarization")
|
||||
save_git_info(self.hparams.output_dir)
|
||||
self.metrics_save_path = Path(self.output_dir) / "metrics.pkl"
|
||||
self.hparams_save_path = Path(self.output_dir) / "hparams.pkl"
|
||||
self.step_count = 0
|
||||
self.metrics = {"train": [], "val": [], "test": []}
|
||||
|
||||
def __init__(self, hparams):
|
||||
super().__init__(hparams, num_labels=None, mode=self.mode)
|
||||
self.dataset_kwargs: dict = dict(
|
||||
data_dir=self.hparams.data_dir,
|
||||
max_source_length=self.hparams.max_source_length,
|
||||
max_target_length=self.hparams.max_target_length,
|
||||
prefix=self.model.config.prefix or "",
|
||||
)
|
||||
n_observations_per_split = {
|
||||
"train": self.hparams.n_train,
|
||||
"val": self.hparams.n_val,
|
||||
"test": self.hparams.n_test,
|
||||
}
|
||||
self.n_obs = {k: v if v >= 0 else None for k, v in n_observations_per_split.items()}
|
||||
|
||||
def forward(self, input_ids, attention_mask=None, decoder_input_ids=None, lm_labels=None):
|
||||
return self.model(
|
||||
input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, lm_labels=lm_labels,
|
||||
self.target_lens = {
|
||||
"train": self.hparams.max_target_length,
|
||||
"val": self.hparams.val_max_target_length,
|
||||
"test": self.hparams.test_max_target_length,
|
||||
}
|
||||
assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
|
||||
assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"
|
||||
|
||||
if self.hparams.freeze_embeds:
|
||||
self.freeze_embeds()
|
||||
if self.hparams.freeze_encoder:
|
||||
freeze_params(self.model.model.encoder) # TODO: this will break for t5
|
||||
self.hparams.git_sha = get_git_info()["repo_sha"]
|
||||
self.num_workers = 4 if self.hparams.gpus <= 1 else None # passing num_workers breaks lightning for multigpu
|
||||
|
||||
def freeze_embeds(self):
|
||||
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
|
||||
if self.model.config.model_type == "bart":
|
||||
freeze_params(self.model.model.shared)
|
||||
for d in [self.model.model.encoder, self.model.model.decoder]:
|
||||
freeze_params(d.embed_positions)
|
||||
freeze_params(d.embed_tokens)
|
||||
else:
|
||||
freeze_params(self.model.shared)
|
||||
for d in [self.model.encoder, self.model.decoder]:
|
||||
freeze_params(d.embed_tokens)
|
||||
|
||||
def forward(self, input_ids, **kwargs):
|
||||
return self.model(input_ids, **kwargs)
|
||||
|
||||
def ids_to_clean_text(self, generated_ids: List[int]):
|
||||
gen_text = self.tokenizer.batch_decode(
|
||||
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
|
||||
)
|
||||
return lmap(str.strip, gen_text)
|
||||
|
||||
def _step(self, batch):
|
||||
def _step(self, batch: dict) -> Tuple:
|
||||
pad_token_id = self.tokenizer.pad_token_id
|
||||
source_ids, source_mask, y = batch["source_ids"], batch["source_mask"], batch["target_ids"]
|
||||
source_ids, source_mask, y = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]
|
||||
y_ids = y[:, :-1].contiguous()
|
||||
lm_labels = y[:, 1:].clone()
|
||||
lm_labels[y[:, 1:] == pad_token_id] = -100
|
||||
outputs = self(source_ids, attention_mask=source_mask, decoder_input_ids=y_ids, lm_labels=lm_labels,)
|
||||
|
||||
outputs = self(source_ids, attention_mask=source_mask, decoder_input_ids=y_ids, labels=lm_labels,)
|
||||
loss = outputs[0]
|
||||
return (loss,)
|
||||
|
||||
return loss
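To make the teacher-forcing setup in `_step` concrete, a tiny standalone example of the decoder-input/label shift and the `-100` masking used above; the ids are made up, and `pad_token_id=1` matches BART.

```python
import torch

pad_token_id = 1  # BART's pad id
y = torch.tensor([[0, 42, 43, 44, 2, 1, 1]])  # padded target ids

y_ids = y[:, :-1].contiguous()               # decoder inputs: drop the last token
lm_labels = y[:, 1:].clone()                 # labels: drop the first token
lm_labels[y[:, 1:] == pad_token_id] = -100   # padding is ignored by the loss

print(y_ids)      # tensor([[ 0, 42, 43, 44,  2,  1]])
print(lm_labels)  # tensor([[  42,   43,   44,    2, -100, -100]])
```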
|
||||
def training_step(self, batch, batch_idx) -> Dict:
|
||||
loss_tensors = self._step(batch)
|
||||
logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
|
||||
return {"loss": loss_tensors[0], "log": logs}
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
loss = self._step(batch)
|
||||
def validation_step(self, batch, batch_idx) -> Dict:
|
||||
return self._generative_step(batch)
|
||||
|
||||
tensorboard_logs = {"train_loss": loss}
|
||||
return {"loss": loss, "log": tensorboard_logs}
|
||||
def validation_end(self, outputs, prefix="val") -> Dict:
|
||||
self.step_count += 1
|
||||
losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
|
||||
loss = losses["loss"]
|
||||
rouges = {k: np.array([x[k] for x in outputs]).mean() for k in ROUGE_KEYS + ["gen_time", "summ_len"]}
|
||||
rouge_tensor: torch.FloatTensor = torch.tensor(rouges["rouge2"]).type_as(loss)
|
||||
rouges.update({k: v.item() for k, v in losses.items()})
|
||||
losses.update(rouges)
|
||||
metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
|
||||
metrics["step_count"] = self.step_count
|
||||
self.save_metrics(metrics, prefix) # writes to self.metrics_save_path
|
||||
preds = flatten_list([x["preds"] for x in outputs])
|
||||
return {"log": metrics, "preds": preds, f"{prefix}_loss": loss, f"{prefix}_rouge": rouge_tensor}
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
loss = self._step(batch)
|
||||
return {"val_loss": loss}
|
||||
def save_metrics(self, metrics, prefix) -> None:
|
||||
self.metrics[prefix].append(metrics)
|
||||
pickle_save(self.metrics, self.metrics_save_path)
|
||||
|
||||
def validation_end(self, outputs):
|
||||
avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
|
||||
tensorboard_logs = {"val_loss": avg_loss}
|
||||
return {"avg_val_loss": avg_loss, "log": tensorboard_logs}
|
||||
|
||||
def test_step(self, batch, batch_idx):
|
||||
def _generative_step(self, batch):
|
||||
pad_token_id = self.tokenizer.pad_token_id
|
||||
source_ids, source_mask, y = SummarizationDataset.trim_seq2seq_batch(batch, pad_token_id)
|
||||
# NOTE: the following kwargs get more speed and lower quality summaries than those in evaluate_cnn.py
|
||||
generated_ids = self.model.generate(
|
||||
input_ids=source_ids,
|
||||
attention_mask=source_mask,
|
||||
num_beams=1,
|
||||
max_length=80,
|
||||
repetition_penalty=2.5,
|
||||
length_penalty=1.0,
|
||||
early_stopping=True,
|
||||
use_cache=True,
|
||||
)
|
||||
preds = [
|
||||
self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
||||
for g in generated_ids
|
||||
]
|
||||
target = [self.tokenizer.decode(t, skip_special_tokens=True, clean_up_tokenization_spaces=True) for t in y]
|
||||
loss = self._step(batch)
|
||||
# TODO(SS): task specific params
|
||||
|
||||
return {"val_loss": loss, "preds": preds, "target": target}
|
||||
t0 = time.time()
|
||||
generated_ids = self.model.generate(input_ids=source_ids, attention_mask=source_mask, use_cache=True,)
|
||||
gen_time = time.time() - t0
|
||||
preds = self.ids_to_clean_text(generated_ids)
|
||||
target = self.ids_to_clean_text(y)
|
||||
loss_tensors = self._step(batch)
|
||||
base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
|
||||
rouge: Dict = calculate_rouge(preds, target)
|
||||
summ_len = np.mean(lmap(len, generated_ids))
|
||||
base_metrics.update(gen_time=gen_time, summ_len=summ_len, preds=preds, target=target, **rouge)
|
||||
return base_metrics
|
||||
|
||||
def test_step(self, batch, batch_idx):
|
||||
return self._generative_step(batch)
|
||||
|
||||
def test_end(self, outputs):
|
||||
return self.validation_end(outputs)
|
||||
return self.validation_end(outputs, prefix="test")
|
||||
|
||||
def test_epoch_end(self, outputs):
|
||||
output_test_predictions_file = os.path.join(self.hparams.output_dir, "test_predictions.txt")
|
||||
@ -102,15 +180,43 @@ class SummarizationTrainer(BaseTransformer):
|
||||
|
||||
return self.test_end(outputs)
|
||||
|
||||
def validation_epoch_end(self, outputs):
|
||||
self.validation_end(outputs, "val")
|
||||
|
||||
def get_dataset(self, type_path) -> SummarizationDataset:
|
||||
n_obs = self.n_obs[type_path]
|
||||
max_target_length = self.target_lens[type_path]
|
||||
dataset = SummarizationDataset(
|
||||
self.tokenizer,
|
||||
type_path=type_path,
|
||||
n_obs=n_obs,
|
||||
max_target_length=max_target_length,
|
||||
**self.dataset_kwargs,
|
||||
)
|
||||
return dataset
|
||||
|
||||
def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
|
||||
dataset = SummarizationDataset(self.tokenizer, type_path=type_path, **self.dataset_kwargs)
|
||||
dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=dataset.collate_fn, shuffle=shuffle)
|
||||
dataset = self.get_dataset(type_path)
|
||||
sampler = None
|
||||
if self.hparams.sortish_sampler and type_path == "train":
|
||||
assert self.hparams.gpus <= 1 # TODO: assert earlier
|
||||
sampler = dataset.make_sortish_sampler(batch_size)
|
||||
shuffle = False
|
||||
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
collate_fn=dataset.collate_fn,
|
||||
shuffle=shuffle,
|
||||
num_workers=self.num_workers,
|
||||
sampler=sampler,
|
||||
)
|
||||
return dataloader
|
||||
|
||||
def train_dataloader(self) -> DataLoader:
|
||||
dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
|
||||
t_total = (
|
||||
(len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.n_gpu)))
|
||||
(len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.gpus)))
|
||||
// self.hparams.gradient_accumulation_steps
|
||||
* float(self.hparams.num_train_epochs)
|
||||
)
|
||||
@ -129,7 +235,7 @@ class SummarizationTrainer(BaseTransformer):
|
||||
@staticmethod
|
||||
def add_model_specific_args(parser, root_dir):
|
||||
BaseTransformer.add_model_specific_args(parser, root_dir)
|
||||
# Add BART specific options
|
||||
add_generic_args(parser, root_dir)
|
||||
parser.add_argument(
|
||||
"--max_source_length",
|
||||
default=1024,
|
||||
@ -144,41 +250,82 @@ class SummarizationTrainer(BaseTransformer):
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--val_max_target_length",
|
||||
default=142, # these defaults are optimized for CNNDM. For xsum, see README.md.
|
||||
type=int,
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test_max_target_length",
|
||||
default=142,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The input data dir. Should contain the dataset files for the CNN/DM summarization task.",
|
||||
help="The input data dir. Should contain train.source, train.target, val.source, val.target, test.source, test.target",
|
||||
)
|
||||
parser.add_argument("--freeze_encoder", action="store_true")
|
||||
parser.add_argument("--freeze_embeds", action="store_true")
|
||||
parser.add_argument("--sortish_sampler", action="store_true", default=False)
|
||||
parser.add_argument("--logger", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
|
||||
parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
|
||||
parser.add_argument("--n_val", type=int, default=500, required=False, help="# examples. -1 means use all.")
|
||||
parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.")
|
||||
return parser
|
||||
|
||||
|
||||
def main(args):
|
||||
def main(args, model=None) -> SummarizationModule:
|
||||
Path(args.output_dir).mkdir(exist_ok=True)
|
||||
if len(os.listdir(args.output_dir)) > 3 and args.do_train:
|
||||
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
|
||||
if model is None:
|
||||
model: BaseTransformer = SummarizationModule(args)
|
||||
if (
|
||||
args.logger == "default"
|
||||
or args.fast_dev_run
|
||||
or str(args.output_dir).startswith("/tmp")
|
||||
or str(args.output_dir).startswith("/var")
|
||||
):
|
||||
logger = True # don't pollute wandb logs unnecessarily
|
||||
elif args.logger == "wandb":
|
||||
from pytorch_lightning.loggers import WandbLogger
|
||||
|
||||
# If output_dir not provided, a folder will be generated in pwd
|
||||
if not args.output_dir:
|
||||
args.output_dir = os.path.join("./results", f"{args.task}_{time.strftime('%Y%m%d_%H%M%S')}",)
|
||||
os.makedirs(args.output_dir)
|
||||
model = SummarizationTrainer(args)
|
||||
trainer = generic_train(model, args)
|
||||
logger = WandbLogger(name=model.output_dir.name)
|
||||
elif args.logger == "wandb_shared":
|
||||
from pytorch_lightning.loggers import WandbLogger
|
||||
|
||||
# Optionally, predict on dev set and write to output_dir
|
||||
if args.do_predict:
|
||||
# See https://github.com/huggingface/transformers/issues/3159
|
||||
# pl use this format to create a checkpoint:
|
||||
# https://github.com/PyTorchLightning/pytorch-lightning/blob/master\
|
||||
# /pytorch_lightning/callbacks/model_checkpoint.py#L169
|
||||
checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True)))
|
||||
model = model.load_from_checkpoint(checkpoints[-1])
|
||||
trainer.test(model)
|
||||
# TODO: separate LB for CNN, we should use Path(args.data_dir).name to determine the correct LB.
|
||||
logger = WandbLogger(name=model.output_dir.name, project="hf_summarization")
|
||||
trainer: pl.Trainer = generic_train(
|
||||
model,
|
||||
args,
|
||||
logging_callback=Seq2SeqLoggingCallback(),
|
||||
checkpoint_callback=get_rouge2_checkpoint_callback(args.output_dir),
|
||||
logger=logger,
|
||||
# TODO: early stopping callback seems messed up
|
||||
)
|
||||
if not args.do_predict:
|
||||
return model
|
||||
|
||||
model.hparams.test_checkpoint = ""
|
||||
checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True)))
|
||||
if checkpoints:
|
||||
model.hparams.test_checkpoint = checkpoints[-1]
|
||||
trainer.resume_from_checkpoint = checkpoints[-1]
|
||||
trainer.logger.log_hyperparams(model.hparams)
|
||||
trainer.test(model) # this breaks in DDP, known lightning issue. See evaluate_checkpoint to recover metrics.
|
||||
return model
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
add_generic_args(parser, os.getcwd())
|
||||
parser = SummarizationTrainer.add_model_specific_args(parser, os.getcwd())
|
||||
parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
|
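For context (not part of this commit): the new main(args, model=None) -> SummarizationModule signature lets tests and notebooks drive fine-tuning programmatically and reuse the returned module. A minimal sketch mirroring the __main__ block above, with placeholder paths and assuming lightning_base.add_generic_args supplies the remaining generic flags:

# Illustrative sketch only; data_dir and output_dir are placeholders.
import argparse
import os

from lightning_base import add_generic_args
from finetune import SummarizationModule, main

parser = argparse.ArgumentParser()
add_generic_args(parser, os.getcwd())
parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
args = parser.parse_args(
    [
        "--data_dir=cnn_dm",  # must contain {train,val,test}.{source,target}
        "--output_dir=bart_sum_output",
        "--model_name_or_path=sshleifer/bart-tiny-random",
        "--do_train",
    ]
)
model = main(args)  # returns the SummarizationModule for further inspection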
23
examples/summarization/finetune.sh
Executable file
@ -0,0 +1,23 @@
|
||||
export OUTPUT_DIR=bart_cnn_finetune
|
||||
|
||||
# Make output directory if it doesn't exist
|
||||
mkdir -p $OUTPUT_DIR
|
||||
|
||||
# Add parent directory to python path to access lightning_base.py
|
||||
export PYTHONPATH="../":"${PYTHONPATH}"
|
||||
|
||||
|
||||
# --model_name_or_path=t5-base for t5
|
||||
|
||||
python finetune.py \
|
||||
--model_name_or_path=facebook/bart-large \
|
||||
--learning_rate=3e-5 \
|
||||
--fp16 \
|
||||
--gpus 1 \
|
||||
--do_train \
|
||||
--do_predict \
|
||||
--n_val 1000 \
|
||||
--val_check_interval 0.1 \
|
||||
--sortish_sampler \
|
||||
--max_target_length=56 \
|
||||
$@
|
@ -1,18 +0,0 @@
|
||||
export OUTPUT_DIR_NAME=bart_sum
|
||||
export CURRENT_DIR=${PWD}
|
||||
export OUTPUT_DIR=${CURRENT_DIR}/${OUTPUT_DIR_NAME}
|
||||
|
||||
# Make output directory if it doesn't exist
|
||||
mkdir -p $OUTPUT_DIR
|
||||
|
||||
# Add parent directory to python path to access lightning_base.py
|
||||
export PYTHONPATH="../":"${PYTHONPATH}"
|
||||
|
||||
python finetune.py \
|
||||
--data_dir=./cnn-dailymail/cnn_dm \
|
||||
--model_name_or_path=bart-large \
|
||||
--learning_rate=3e-5 \
|
||||
--train_batch_size=4 \
|
||||
--eval_batch_size=4 \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--do_train $@
|
20
examples/summarization/initialization_utils.py
Normal file
@ -0,0 +1,20 @@
|
||||
from typing import List
|
||||
|
||||
from torch import nn
|
||||
|
||||
|
||||
def init_student(student, teacher):
|
||||
teacher_state_dict = teacher.state_dict()
|
||||
info = student.load_state_dict(teacher_state_dict, strict=False)
|
||||
assert info.missing_keys == [], info.missing_keys
|
||||
return student, info
|
||||
|
||||
|
||||
def copy_decoder_layers(teacher, student, l2copy=[0, 2, 4, 7, 9, 11]):
|
||||
copy_layers(teacher.model.decoder.layers, student.model.decoder.layers, l2copy)
|
||||
|
||||
|
||||
def copy_layers(teacher_layers: nn.ModuleList, student_layers: nn.ModuleList, layers_to_copy: List) -> None:
|
||||
layers_to_copy = nn.ModuleList([l for i, l in enumerate(teacher_layers) if i in layers_to_copy])
|
||||
assert len(student_layers) == len(layers_to_copy), f"{len(student_layers)} != {len(layers_to_copy)}"
|
||||
student_layers.load_state_dict(layers_to_copy.state_dict())
|
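A minimal sketch (not part of this commit) of how these helpers might be combined to initialize a smaller student from a teacher; the checkpoint name and layer choice are placeholders:

# Illustrative only: shrink a BART-style teacher to a 1-decoder-layer student.
from transformers import AutoConfig, AutoModelForSeq2SeqLM

from initialization_utils import copy_layers, init_student

teacher = AutoModelForSeq2SeqLM.from_pretrained("sshleifer/bart-tiny-random")
student_config = AutoConfig.from_pretrained("sshleifer/bart-tiny-random", decoder_layers=1)
student = AutoModelForSeq2SeqLM.from_config(student_config)

student, info = init_student(student, teacher)  # copy every weight the student shares with the teacher
copy_layers(teacher.model.decoder.layers, student.model.decoder.layers, [0])  # then overwrite chosen layers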
12
examples/summarization/run_distiller.sh
Executable file
@ -0,0 +1,12 @@
|
||||
#CNN_DIR = /home/shleifer/transformers_fork/examples/summarization/bart/cnn_dm
|
||||
|
||||
# Add parent directory to python path to access lightning_base.py
|
||||
export PYTHONPATH="../":"${PYTHONPATH}"
|
||||
|
||||
python distillation.py \
|
||||
--learning_rate=3e-4 \
|
||||
--do_train \
|
||||
--do_predict \
|
||||
--fp16 \
|
||||
--val_check_interval 0.1 \
|
||||
$@
|
78
examples/summarization/run_eval.py
Normal file
@ -0,0 +1,78 @@
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
||||
|
||||
|
||||
try:
|
||||
from .finetune import calculate_rouge, use_task_specific_params
|
||||
except ImportError:
|
||||
from finetune import calculate_rouge, use_task_specific_params
|
||||
|
||||
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
def chunks(lst, n):
|
||||
"""Yield successive n-sized chunks from lst."""
|
||||
for i in range(0, len(lst), n):
|
||||
yield lst[i : i + n]
|
||||
|
||||
|
||||
def generate_summaries(
|
||||
examples: list, out_file: str, model_name: str, batch_size: int = 8, device: str = DEFAULT_DEVICE, fp16=False,
|
||||
) -> None:
|
||||
fout = Path(out_file).open("w", encoding="utf-8")
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
|
||||
if fp16:
|
||||
model = model.half()
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
|
||||
# update config with summarization specific params
|
||||
use_task_specific_params(model, "summarization")
|
||||
|
||||
for batch in tqdm(list(chunks(examples, batch_size))):
|
||||
if "t5" in model_name:
|
||||
batch = [model.config.prefix + text for text in batch]
|
||||
dct = tokenizer.batch_encode_plus(batch, max_length=1024, return_tensors="pt", pad_to_max_length=True).to(
|
||||
device
|
||||
)
|
||||
summaries = model.generate(**dct)
|
||||
|
||||
dec = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||
for hypothesis in dec:
|
||||
fout.write(hypothesis + "\n")
|
||||
fout.flush()
|
||||
|
||||
|
||||
def run_generate():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("input_path", type=str, help="like cnn_dm/test.source")
|
||||
parser.add_argument("output_path", type=str, help="where to save summaries")
|
||||
parser.add_argument("model_name", type=str, help="like facebook/bart-large-cnn,t5-base, etc.")
|
||||
parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test_reference_summaries.txt")
|
||||
parser.add_argument("--score_path", type=str, required=False, help="where to save the rouge score in json format")
|
||||
parser.add_argument("--device", type=str, required=False, default=DEFAULT_DEVICE, help="cuda, cuda:1, cpu etc.")
|
||||
parser.add_argument("--bs", type=int, default=8, required=False, help="batch size")
|
||||
parser.add_argument("--fp16", action="store_true")
|
||||
args = parser.parse_args()
|
||||
examples = [" " + x.rstrip() if "t5" in args.model_name else x.rstrip() for x in open(args.input_path).readlines()]
|
||||
|
||||
generate_summaries(
|
||||
examples, args.output_path, args.model_name, batch_size=args.bs, device=args.device, fp16=args.fp16
|
||||
)
|
||||
if args.score_path is not None:
|
||||
output_lns = [x.rstrip() for x in open(args.output_path).readlines()]
|
||||
reference_lns = [x.rstrip() for x in open(args.reference_path).readlines()]
|
||||
|
||||
rouge: dict = calculate_rouge(output_lns, reference_lns)
|
||||
|
||||
json.dump(rouge, open("score_path", "w+"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_generate()
|
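For reference (not part of this commit), generate_summaries can also be called directly from Python; a minimal sketch with placeholder file and checkpoint names:

# Illustrative only: summarize a couple of articles with a tiny checkpoint on CPU.
from run_eval import generate_summaries

articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]
generate_summaries(articles, "summaries.txt", "sshleifer/bart-tiny-random", batch_size=2, device="cpu")
print(open("summaries.txt").read())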
@ -7,28 +7,40 @@ import unittest
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from transformers import BartTokenizer
|
||||
|
||||
from .evaluate_cnn import run_generate
|
||||
from .distillation import distill_main, evaluate_checkpoint
|
||||
from .finetune import main
|
||||
from .utils import SummarizationDataset
|
||||
from .run_eval import generate_summaries, run_generate
|
||||
from .utils import SummarizationDataset, lmap, pickle_load
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
DEFAULT_ARGS = {
|
||||
FP16_EVER = False
|
||||
CHEAP_ARGS = {
|
||||
"logger": "default",
|
||||
"alpha_hid": 0,
|
||||
"freeze_embeds": True,
|
||||
"enc_only": False,
|
||||
"tgt_suffix": "",
|
||||
"resume_from_checkpoint": None,
|
||||
"sortish_sampler": True,
|
||||
"student_decoder_layers": 1,
|
||||
"val_check_interval": 1.0,
|
||||
"output_dir": "",
|
||||
"fp16": False,
|
||||
"no_teacher": False,
|
||||
"fp16_opt_level": "O1",
|
||||
"n_gpu": 1,
|
||||
"gpus": 1 if torch.cuda.is_available() else 0,
|
||||
"n_tpu_cores": 0,
|
||||
"max_grad_norm": 1.0,
|
||||
"do_train": True,
|
||||
"do_predict": False,
|
||||
"do_predict": True,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"server_ip": "",
|
||||
"server_port": "",
|
||||
@ -36,7 +48,7 @@ DEFAULT_ARGS = {
|
||||
"model_type": "bart",
|
||||
"model_name_or_path": "sshleifer/bart-tiny-random",
|
||||
"config_name": "",
|
||||
"tokenizer_name": "",
|
||||
"tokenizer_name": "facebook/bart-large",
|
||||
"cache_dir": "",
|
||||
"do_lower_case": False,
|
||||
"learning_rate": 3e-05,
|
||||
@ -48,6 +60,17 @@ DEFAULT_ARGS = {
|
||||
"eval_batch_size": 2,
|
||||
"max_source_length": 12,
|
||||
"max_target_length": 12,
|
||||
"val_max_target_length": 12,
|
||||
"test_max_target_length": 12,
|
||||
"fast_dev_run": False,
|
||||
"no_cache": False,
|
||||
"n_train": -1,
|
||||
"n_val": -1,
|
||||
"n_test": -1,
|
||||
"student_encoder_layers": 1,
|
||||
"alpha_loss_encoder": 0.0,
|
||||
"freeze_encoder": False,
|
||||
"auto_scale_batch_size": False,
|
||||
}
|
||||
|
||||
|
||||
@ -56,6 +79,9 @@ def _dump_articles(path: Path, articles: list):
|
||||
f.write("\n".join(articles))
|
||||
|
||||
|
||||
BDIR = Path("~/transformers_fork/examples/summarization/bart/").absolute()
|
||||
|
||||
|
||||
def make_test_data_dir():
|
||||
tmp_dir = Path(tempfile.gettempdir())
|
||||
articles = [" Sam ate lunch today", "Sams lunch ingredients"]
|
||||
@ -66,6 +92,169 @@ def make_test_data_dir():
|
||||
return tmp_dir
|
||||
|
||||
|
||||
@unittest.skip("These wont' pass until hidden_states kwarg is merged.")
|
||||
class TestSummarizationDistiller(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks
|
||||
return cls
|
||||
|
||||
@unittest.skipUnless(torch.cuda.device_count() > 1, "skipping multiGPU test")
|
||||
def test_bdc_multigpu(self):
|
||||
updates = dict(
|
||||
student_encoder_layers=2,
|
||||
student_decoder_layers=1,
|
||||
no_teacher=True,
|
||||
freeze_encoder=True,
|
||||
gpus=2,
|
||||
sortish_sampler=False,
|
||||
)
|
||||
self._bart_distiller_cli(updates)
|
||||
|
||||
@unittest.skipUnless(torch.cuda.is_available(), "skipping fp16 test")
|
||||
def test_bdc_fp16(self):
|
||||
updates = dict(
|
||||
student_encoder_layers=2,
|
||||
student_decoder_layers=1,
|
||||
alpha_hid=3.0,
|
||||
freeze_encoder=True,
|
||||
gpus=1,
|
||||
fp16=FP16_EVER,
|
||||
fp16_opt_level="O1",
|
||||
)
|
||||
self._bart_distiller_cli(updates)
|
||||
|
||||
@unittest.skipUnless(torch.cuda.is_available(), "skipping fp16 test")
|
||||
def test_bdc_t5_eval_fp16(self):
|
||||
updates = dict(
|
||||
fp16=FP16_EVER,
|
||||
gpus=1,
|
||||
model_type="t5",
|
||||
model_name_or_path="patrickvonplaten/t5-tiny-random",
|
||||
do_train=False,
|
||||
do_predict=True,
|
||||
tokenizer_name=None,
|
||||
no_teacher=True,
|
||||
)
|
||||
self._bart_distiller_cli(updates, check_contents=False)
|
||||
|
||||
@unittest.skipUnless(torch.cuda.is_available(), "skipping fp16 test")
|
||||
def test_bdc_t5_train_fp16(self):
|
||||
updates = dict(
|
||||
fp16=FP16_EVER,
|
||||
gpus=1,
|
||||
model_type="t5",
|
||||
model_name_or_path="patrickvonplaten/t5-tiny-random",
|
||||
do_train=True,
|
||||
do_predict=True,
|
||||
tokenizer_name="patrickvonplaten/t5-tiny-random",
|
||||
no_teacher=True,
|
||||
)
|
||||
self._bart_distiller_cli(updates)
|
||||
|
||||
def test_bdc_no_teacher(self):
|
||||
updates = dict(student_encoder_layers=2, student_decoder_layers=1, no_teacher=True,)
|
||||
self._bart_distiller_cli(updates)
|
||||
|
||||
def test_bdc_yes_teacher(self):
|
||||
updates = dict(student_encoder_layers=2, student_decoder_layers=1,)
|
||||
self._bart_distiller_cli(updates)
|
||||
|
||||
def test_bdc_checkpointing(self):
|
||||
|
||||
updates = dict(
|
||||
student_encoder_layers=2,
|
||||
student_decoder_layers=1,
|
||||
num_train_epochs=4,
|
||||
val_check_interval=0.25,
|
||||
alpha_hid=2.0,
|
||||
)
|
||||
model = self._bart_distiller_cli(updates, check_contents=False)
|
||||
|
||||
ckpts = list(Path(model.output_dir).glob("*.ckpt"))
|
||||
self.assertEqual(1, len(ckpts))
|
||||
transformer_ckpts = list(Path(model.output_dir).glob("**/*.bin"))
|
||||
self.assertEqual(len(transformer_ckpts), len(ckpts))
|
||||
new_transformer_ckpts = list(Path(model.output_dir).glob("**/*.bin"))
|
||||
self.assertEqual(len(new_transformer_ckpts), 1)
|
||||
examples = lmap(str.strip, model.hparams.data_dir.joinpath("test.source").open().readlines())
|
||||
out_path = tempfile.mktemp()
|
||||
generate_summaries(examples, out_path, new_transformer_ckpts[0].parent)
|
||||
self.assertTrue(Path(out_path).exists())
|
||||
|
||||
evaluate_checkpoint(ckpts[0], dest_dir=Path(tempfile.mkdtemp()))
|
||||
|
||||
def test_bdc_t5(self):
|
||||
updates = dict(
|
||||
student_encoder_layers=1,
|
||||
student_decoder_layers=1,
|
||||
alpha_hid=2.0,
|
||||
teacher="patrickvonplaten/t5-tiny-random",
|
||||
model_type="t5",
|
||||
model_name_or_path="patrickvonplaten/t5-tiny-random",
|
||||
tokenizer_name="patrickvonplaten/t5-tiny-random",
|
||||
)
|
||||
self._bart_distiller_cli(updates)
|
||||
|
||||
def test_bdc_t5_eval(self):
|
||||
updates = dict(
|
||||
model_type="t5",
|
||||
model_name_or_path="patrickvonplaten/t5-tiny-random",
|
||||
do_train=False,
|
||||
do_predict=True,
|
||||
tokenizer_name="patrickvonplaten/t5-tiny-random",
|
||||
no_teacher=True,
|
||||
)
|
||||
self._bart_distiller_cli(updates, check_contents=False)
|
||||
|
||||
def _bart_distiller_cli(self, updates, check_contents=True):
|
||||
default_updates = dict(
|
||||
model_type="bart",
|
||||
train_batch_size=1,
|
||||
eval_batch_size=2,
|
||||
num_train_epochs=2,
|
||||
alpha_mlm=0.2,
|
||||
alpha_ce=0.8,
|
||||
do_predict=True,
|
||||
gpus=1 if torch.cuda.is_available() else 0,
|
||||
model_name_or_path="sshleifer/tinier_bart",
|
||||
teacher=CHEAP_ARGS["model_name_or_path"],
|
||||
val_check_interval=0.5,
|
||||
alpha_encoder_loss=0.4,
|
||||
)
|
||||
default_updates.update(updates)
|
||||
args_d: dict = CHEAP_ARGS.copy()
|
||||
tmp_dir = make_test_data_dir()
|
||||
output_dir = tempfile.mkdtemp(prefix="output_")
|
||||
|
||||
args_d.update(data_dir=tmp_dir, output_dir=output_dir, **default_updates)
|
||||
model = distill_main(argparse.Namespace(**args_d))
|
||||
if not check_contents:
|
||||
return model
|
||||
contents = os.listdir(output_dir)
|
||||
ckpt_name = "val_avg_rouge2=0.0000-step_count=2.ckpt" # "val_avg_rouge2=0.0000-epoch=1.ckpt" # "epoch=1-val_avg_rouge2=0.0000.ckpt"
|
||||
contents = {os.path.basename(p) for p in contents}
|
||||
self.assertIn(ckpt_name, contents)
|
||||
self.assertIn("metrics.pkl", contents)
|
||||
self.assertIn("test_generations.txt", contents)
|
||||
self.assertIn("val_generations_1.txt", contents)
|
||||
self.assertIn("val_1_results.txt", contents)
|
||||
self.assertIn("test_results.txt", contents)
|
||||
# self.assertEqual(len(contents), 15)
|
||||
|
||||
metrics = pickle_load(Path(output_dir) / "metrics.pkl")
|
||||
import pandas as pd
|
||||
|
||||
val_df = pd.DataFrame(metrics["val"])
|
||||
train_df = pd.DataFrame(metrics["train"])
|
||||
test_df = pd.DataFrame(metrics["test"])
|
||||
desired_n_evals = args_d["num_train_epochs"] * 2 + 1
|
||||
self.assertEqual(val_df.shape[0], desired_n_evals)
|
||||
self.assertEqual(test_df.shape[1], val_df.shape[1])
|
||||
self.assertEqual(train_df.shape[0], 0)
|
||||
return model
|
||||
|
||||
|
||||
class TestBartExamples(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
@ -79,49 +268,31 @@ class TestBartExamples(unittest.TestCase):
|
||||
output_file_name = Path(tempfile.gettempdir()) / "utest_output_bart_sum.hypo"
|
||||
articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]
|
||||
_dump_articles(tmp, articles)
|
||||
testargs = ["evaluate_cnn.py", str(tmp), str(output_file_name), "sshleifer/bart-tiny-random"]
|
||||
testargs = ["run_eval.py", str(tmp), str(output_file_name), "sshleifer/bart-tiny-random"]
|
||||
with patch.object(sys, "argv", testargs):
|
||||
run_generate()
|
||||
self.assertTrue(Path(output_file_name).exists())
|
||||
os.remove(Path(output_file_name))
|
||||
|
||||
def test_bart_run_sum_cli(self):
|
||||
args_d: dict = DEFAULT_ARGS.copy()
|
||||
tmp_dir = make_test_data_dir()
|
||||
output_dir = tempfile.mkdtemp(prefix="output_")
|
||||
args_d.update(
|
||||
data_dir=tmp_dir, model_type="bart", train_batch_size=2, eval_batch_size=2, n_gpu=0, output_dir=output_dir,
|
||||
)
|
||||
main(argparse.Namespace(**args_d))
|
||||
args_d.update({"do_train": False, "do_predict": True})
|
||||
|
||||
main(argparse.Namespace(**args_d))
|
||||
contents = os.listdir(output_dir)
|
||||
expected_contents = {
|
||||
"checkpointepoch=0.ckpt",
|
||||
"test_results.txt",
|
||||
}
|
||||
created_files = {os.path.basename(p) for p in contents}
|
||||
self.assertSetEqual(expected_contents, created_files)
|
||||
|
||||
def test_t5_run_sum_cli(self):
|
||||
args_d: dict = DEFAULT_ARGS.copy()
|
||||
args_d: dict = CHEAP_ARGS.copy()
|
||||
|
||||
tmp_dir = make_test_data_dir()
|
||||
output_dir = tempfile.mkdtemp(prefix="output_")
|
||||
args_d.update(
|
||||
data_dir=tmp_dir,
|
||||
model_type="t5",
|
||||
model_name_or_path="patrickvonplaten/t5-tiny-random",
|
||||
tokenizer_name=None, # "patrickvonplaten/t5-tiny-random",
|
||||
train_batch_size=2,
|
||||
eval_batch_size=2,
|
||||
n_gpu=0,
|
||||
gpus=0,
|
||||
output_dir=output_dir,
|
||||
do_predict=True,
|
||||
)
|
||||
main(argparse.Namespace(**args_d))
|
||||
|
||||
# args_d.update({"do_train": False, "do_predict": True})
|
||||
# main(argparse.Namespace(**args_d))
|
||||
assert "n_train" in args_d
|
||||
args = argparse.Namespace(**args_d)
|
||||
main(args)
|
||||
|
||||
def test_bart_summarization_dataset(self):
|
||||
tmp_dir = Path(tempfile.gettempdir())
|
||||
@ -138,42 +309,16 @@ class TestBartExamples(unittest.TestCase):
|
||||
)
|
||||
dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
|
||||
for batch in dataloader:
|
||||
self.assertEqual(batch["source_mask"].shape, batch["source_ids"].shape)
|
||||
self.assertEqual(batch["attention_mask"].shape, batch["input_ids"].shape)
|
||||
# show that articles were trimmed.
|
||||
self.assertEqual(batch["source_ids"].shape[1], max_len_source)
|
||||
self.assertGreater(20, batch["source_ids"].shape[1]) # trimmed significantly
|
||||
self.assertEqual(batch["input_ids"].shape[1], max_len_source)
|
||||
self.assertGreater(20, batch["input_ids"].shape[1]) # trimmed significantly
|
||||
|
||||
# show that targets were truncated
|
||||
self.assertEqual(batch["target_ids"].shape[1], trunc_target) # Truncated
|
||||
self.assertEqual(batch["decoder_input_ids"].shape[1], trunc_target) # Truncated
|
||||
self.assertGreater(max_len_target, trunc_target) # Truncated
|
||||
|
||||
|
||||
class TestT5Examples(unittest.TestCase):
|
||||
def test_t5_cli(self):
|
||||
output_file_name = "output_t5_sum.txt"
|
||||
score_file_name = "score_t5_sum.txt"
|
||||
articles = ["New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
tmp = Path(tempfile.gettempdir()) / "utest_generations_t5_sum.hypo"
|
||||
with tmp.open("w", encoding="utf-8") as f:
|
||||
f.write("\n".join(articles))
|
||||
|
||||
output_file_name = Path(tempfile.gettempdir()) / "utest_output_t5_sum.hypo"
|
||||
score_file_name = Path(tempfile.gettempdir()) / "utest_score_t5_sum.hypo"
|
||||
|
||||
testargs = [
|
||||
"evaluate_cnn.py",
|
||||
str(tmp),
|
||||
str(output_file_name),
|
||||
"patrickvonplaten/t5-tiny-random",
|
||||
"--reference_path",
|
||||
str(tmp),
|
||||
"--score_path",
|
||||
str(score_file_name),
|
||||
]
|
||||
|
||||
with patch.object(sys, "argv", testargs):
|
||||
run_generate()
|
||||
self.assertTrue(Path(output_file_name).exists())
|
||||
self.assertTrue(Path(score_file_name).exists())
|
||||
def list_to_text_file(lst, path):
|
||||
dest = Path(path)
|
||||
dest.open("w+").writelines(lst)
|
||||
|
@ -1,20 +1,66 @@
|
||||
import itertools
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
from typing import Dict, Iterable, List
|
||||
|
||||
import git
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from rouge_score import rouge_scorer, scoring
|
||||
from torch import nn
|
||||
from torch.utils.data import Dataset, Sampler
|
||||
from tqdm import tqdm
|
||||
|
||||
from transformers import BartTokenizer
|
||||
|
||||
|
||||
def encode_file(tokenizer, data_path, max_length, pad_to_max_length=True, return_tensors="pt"):
|
||||
def encode_file(
|
||||
tokenizer,
|
||||
data_path,
|
||||
max_length,
|
||||
pad_to_max_length=True,
|
||||
return_tensors="pt",
|
||||
overwrite_cache=False,
|
||||
prefix="",
|
||||
tok_name="",
|
||||
):
|
||||
cache_path = Path(f"{data_path}_{tok_name}{max_length}.pt")
|
||||
if not overwrite_cache and cache_path.exists():
|
||||
try:
|
||||
examples = torch.load(cache_path)
|
||||
assert isinstance(examples, list)
|
||||
return examples
|
||||
|
||||
except Exception:
|
||||
print(f"failed to load from {cache_path}, retokenizing {data_path}")
|
||||
data_path = Path(data_path)
|
||||
|
||||
lns = lmap(str.strip, data_path.open().readlines())
|
||||
lns = [prefix + text for text in lns]
|
||||
assert lns, f"found empty file at {data_path}"
|
||||
examples = []
|
||||
with open(data_path, "r") as f:
|
||||
for text in f.readlines():
|
||||
tokenized = tokenizer.batch_encode_plus(
|
||||
[text], max_length=max_length, pad_to_max_length=pad_to_max_length, return_tensors=return_tensors,
|
||||
)
|
||||
examples.append(tokenized)
|
||||
for text in tqdm(lns, desc=f"Tokenizing {data_path.name}"):
|
||||
tokenized = tokenizer.batch_encode_plus(
|
||||
[text], # DONT ADD SPACES
|
||||
max_length=max_length,
|
||||
pad_to_max_length=pad_to_max_length,
|
||||
add_prefix_space=True,
|
||||
return_tensors=return_tensors,
|
||||
)
|
||||
examples.append(tokenized)
|
||||
torch.save(lmap(dict, examples), cache_path.open("wb"))
|
||||
return examples
|
||||
|
||||
|
||||
def lmap(f, x):
|
||||
return list(map(f, x))
|
||||
|
||||
|
||||
T5_PREFIX = "summarize: " # HACK, fixme
|
||||
|
||||
|
||||
def trim_batch(
|
||||
input_ids, pad_token_id, attention_mask=None,
|
||||
):
|
||||
@ -30,15 +76,38 @@ class SummarizationDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer,
|
||||
data_dir="./cnn-dailymail/cnn_dm/",
|
||||
data_dir,
|
||||
type_path="train",
|
||||
max_source_length=1024,
|
||||
max_target_length=56,
|
||||
n_obs=None,
|
||||
overwrite_cache=False,
|
||||
prefix="",
|
||||
):
|
||||
super().__init__()
|
||||
self.tokenizer = tokenizer
|
||||
tok_name = "T5" if not isinstance(tokenizer, BartTokenizer) else ""
|
||||
self.source = encode_file(
|
||||
tokenizer,
|
||||
os.path.join(data_dir, type_path + ".source"),
|
||||
max_source_length,
|
||||
overwrite_cache=overwrite_cache,
|
||||
prefix=prefix,
|
||||
tok_name=tok_name,
|
||||
)
|
||||
if type_path == "train":
|
||||
tgt_path = os.path.join(data_dir, type_path + ".target")
|
||||
else:
|
||||
tgt_path = os.path.join(data_dir, type_path + ".target")
|
||||
|
||||
self.target = encode_file(
|
||||
tokenizer, tgt_path, max_target_length, overwrite_cache=overwrite_cache, tok_name=tok_name
|
||||
)
|
||||
self.source = encode_file(tokenizer, os.path.join(data_dir, type_path + ".source"), max_source_length)
|
||||
self.target = encode_file(tokenizer, os.path.join(data_dir, type_path + ".target"), max_target_length)
|
||||
if n_obs is not None:
|
||||
self.source = self.source[:n_obs]
|
||||
self.target = self.target[:n_obs]
|
||||
self.pad_token_id = tokenizer.pad_token_id
|
||||
|
||||
def __len__(self):
|
||||
return len(self.source)
|
||||
@ -47,19 +116,141 @@ class SummarizationDataset(Dataset):
|
||||
source_ids = self.source[index]["input_ids"].squeeze()
|
||||
target_ids = self.target[index]["input_ids"].squeeze()
|
||||
src_mask = self.source[index]["attention_mask"].squeeze()
|
||||
return {"source_ids": source_ids, "source_mask": src_mask, "target_ids": target_ids}
|
||||
return {"input_ids": source_ids, "attention_mask": src_mask, "decoder_input_ids": target_ids}
|
||||
|
||||
@staticmethod
|
||||
def trim_seq2seq_batch(batch, pad_token_id):
|
||||
y = trim_batch(batch["target_ids"], pad_token_id)
|
||||
source_ids, source_mask = trim_batch(batch["source_ids"], pad_token_id, attention_mask=batch["source_mask"])
|
||||
y = trim_batch(batch["decoder_input_ids"], pad_token_id)
|
||||
source_ids, source_mask = trim_batch(batch["input_ids"], pad_token_id, attention_mask=batch["attention_mask"])
|
||||
return source_ids, source_mask, y
|
||||
|
||||
def collate_fn(self, batch):
|
||||
input_ids = torch.stack([x["source_ids"] for x in batch])
|
||||
masks = torch.stack([x["source_mask"] for x in batch])
|
||||
target_ids = torch.stack([x["target_ids"] for x in batch])
|
||||
pad_token_id = self.tokenizer.pad_token_id
|
||||
def collate_fn(self, batch) -> dict:
|
||||
input_ids = torch.stack([x["input_ids"] for x in batch])
|
||||
masks = torch.stack([x["attention_mask"] for x in batch])
|
||||
target_ids = torch.stack([x["decoder_input_ids"] for x in batch])
|
||||
pad_token_id = self.pad_token_id
|
||||
y = trim_batch(target_ids, pad_token_id)
|
||||
source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks)
|
||||
return {"source_ids": source_ids, "source_mask": source_mask, "target_ids": y}
|
||||
batch = {"input_ids": source_ids, "attention_mask": source_mask, "decoder_input_ids": y}
|
||||
return batch
|
||||
|
||||
@property
|
||||
def src_lens(self): # Can delete?
|
||||
return lmap(len, self.source)
|
||||
|
||||
@property
|
||||
def tgt_lens(self):
|
||||
return lmap(len, self.target)
|
||||
|
||||
def make_sortish_sampler(self, batch_size):
|
||||
return SortishSampler(self.source, batch_size)
|
||||
|
||||
|
||||
class SortishSampler(Sampler):
|
||||
"Go through the text data by order of src length with a bit of randomness. From fastai repo."
|
||||
|
||||
def __init__(self, data, batch_size):
|
||||
self.data, self.bs = data, batch_size
|
||||
|
||||
def key(self, i):
|
||||
return len(self.data[i])
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.data)
|
||||
|
||||
def __iter__(self):
|
||||
idxs = np.random.permutation(len(self.data))
|
||||
sz = self.bs * 50
|
||||
ck_idx = [idxs[i : i + sz] for i in range(0, len(idxs), sz)]
|
||||
sort_idx = np.concatenate([sorted(s, key=self.key, reverse=True) for s in ck_idx])
|
||||
sz = self.bs
|
||||
ck_idx = [sort_idx[i : i + sz] for i in range(0, len(sort_idx), sz)]
|
||||
max_ck = np.argmax([self.key(ck[0]) for ck in ck_idx]) # find the chunk with the largest key,
|
||||
ck_idx[0], ck_idx[max_ck] = ck_idx[max_ck], ck_idx[0] # then make sure it goes first.
|
||||
sort_idx = np.concatenate(np.random.permutation(ck_idx[1:])) if len(ck_idx) > 1 else np.array([], dtype=np.int)
|
||||
sort_idx = np.concatenate((ck_idx[0], sort_idx))
|
||||
return iter(sort_idx)
|
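A minimal sketch (not part of this commit) of building a SummarizationDataset and a length-sorted DataLoader; the data directory is a placeholder that must contain {train,val,test}.{source,target}:

# Illustrative only: sortish sampling groups sources of similar length to reduce padding.
from torch.utils.data import DataLoader
from transformers import BartTokenizer

from utils import SummarizationDataset

tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
dataset = SummarizationDataset(tokenizer, data_dir="cnn_dm", type_path="val", max_source_length=1024, max_target_length=56)
sampler = dataset.make_sortish_sampler(batch_size=8)
loader = DataLoader(dataset, batch_size=8, sampler=sampler, collate_fn=dataset.collate_fn)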
||||
|
||||
|
||||
def use_task_specific_params(model, task):
|
||||
# update config with summarization specific params
|
||||
task_specific_params = model.config.task_specific_params
|
||||
if task_specific_params is not None:
|
||||
model.config.update(task_specific_params.get(task, {}))
|
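A minimal sketch (not part of this commit) of use_task_specific_params; it assumes the checkpoint (here t5-small) ships task_specific_params in its config:

# Illustrative only: merge the checkpoint's summarization defaults into the live config.
from transformers import AutoModelForSeq2SeqLM

from utils import use_task_specific_params

model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
use_task_specific_params(model, "summarization")  # merges config.task_specific_params["summarization"]
print(model.config.prefix, model.config.max_length)  # e.g. "summarize: " and 200 for t5-small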
||||
|
||||
|
||||
def pickle_load(path):
|
||||
"""pickle.load(path)"""
|
||||
with open(path, "rb") as f:
|
||||
return pickle.load(f)
|
||||
|
||||
|
||||
def pickle_save(obj, path):
|
||||
"""pickle.dump(obj, path)"""
|
||||
with open(path, "wb") as f:
|
||||
return pickle.dump(obj, f)
|
||||
|
||||
|
||||
def flatten_list(summary_ids: List[List]):
|
||||
return [x for x in itertools.chain.from_iterable(summary_ids)]
|
||||
|
||||
|
||||
def save_git_info(folder_path: str):
|
||||
"""
|
||||
Log commit info.
|
||||
"""
|
||||
repo_infos = get_git_info()
|
||||
|
||||
with open(os.path.join(folder_path, "git_log.json"), "w") as f:
|
||||
json.dump(repo_infos, f, indent=4)
|
||||
|
||||
|
||||
def get_git_info():
|
||||
repo = git.Repo(search_parent_directories=True)
|
||||
repo_infos = {
|
||||
"repo_id": str(repo),
|
||||
"repo_sha": str(repo.head.object.hexsha),
|
||||
"repo_branch": str(repo.active_branch),
|
||||
}
|
||||
return repo_infos
|
||||
|
||||
|
||||
ROUGE_KEYS = ["rouge1", "rouge2", "rougeL"]
|
||||
|
||||
|
||||
def calculate_rouge(output_lns: List[str], reference_lns: List[str]) -> Dict:
|
||||
scorer = rouge_scorer.RougeScorer(ROUGE_KEYS, use_stemmer=True)
|
||||
aggregator = scoring.BootstrapAggregator()
|
||||
|
||||
for reference_ln, output_ln in zip(reference_lns, output_lns):
|
||||
scores = scorer.score(reference_ln, output_ln)
|
||||
aggregator.add_scores(scores)
|
||||
|
||||
result = aggregator.aggregate()
|
||||
return {k: v.mid.fmeasure for k, v in result.items()}
|
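A minimal sketch (not part of this commit) of calculate_rouge on a single hypothesis/reference pair:

# Illustrative only: returns mid f-measures keyed by rouge1 / rouge2 / rougeL.
from utils import calculate_rouge

hyps = ["The cat sat on the mat."]
refs = ["A cat was sitting on the mat."]
print(calculate_rouge(hyps, refs))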
||||
|
||||
|
||||
def freeze_params(model: nn.Module):
|
||||
for par in model.parameters():
|
||||
par.requires_grad = False
|
||||
|
||||
|
||||
def grad_status(model: nn.Module) -> Iterable:
|
||||
return (par.requires_grad for par in model.parameters())
|
||||
|
||||
|
||||
def any_requires_grad(model: nn.Module) -> bool:
|
||||
return any(grad_status(model))
|
||||
|
||||
|
||||
def assert_all_frozen(model):
|
||||
model_grads: List[bool] = list(grad_status(model))
|
||||
n_require_grad = sum(lmap(int, model_grads))
|
||||
npars = len(model_grads)
|
||||
assert not any(model_grads), f"{n_require_grad/npars:.1%} of {npars} weights require grad"
|
||||
|
||||
|
||||
def assert_not_all_frozen(model):
|
||||
model_grads: List[bool] = list(grad_status(model))
|
||||
npars = len(model_grads)
|
||||
assert any(model_grads), f"none of {npars} weights require grad"
|
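A minimal sketch (not part of this commit) of the freezing helpers applied to a seq2seq model; the checkpoint name is a placeholder:

# Illustrative only: freeze the encoder and verify the rest of the model stays trainable.
from transformers import AutoModelForSeq2SeqLM

from utils import assert_all_frozen, assert_not_all_frozen, freeze_params

model = AutoModelForSeq2SeqLM.from_pretrained("sshleifer/bart-tiny-random")
freeze_params(model.get_encoder())      # encoder weights stop receiving gradients
assert_all_frozen(model.get_encoder())
assert_not_all_frozen(model)            # decoder and lm head remain trainable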
||||
|
@ -59,7 +59,7 @@ BART_GENERATION_EXAMPLE = r"""
|
||||
Examples::
|
||||
|
||||
from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig
|
||||
# see ``examples/summarization/bart/evaluate_cnn.py`` for a longer example
|
||||
# see ``examples/summarization/bart/run_eval.py`` for a longer example
|
||||
model = BartForConditionalGeneration.from_pretrained('bart-large-cnn')
|
||||
tokenizer = BartTokenizer.from_pretrained('bart-large-cnn')
|
||||
ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
|
||||
|