mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 03:01:07 +06:00
[examples] SummarizationModule improvements (#4951)
This commit is contained in:
parent
cd40f6564e
commit
043f9f51f9
@ -2,6 +2,8 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
@ -13,10 +15,13 @@ from transformers import (
|
|||||||
AutoModel,
|
AutoModel,
|
||||||
AutoModelForPreTraining,
|
AutoModelForPreTraining,
|
||||||
AutoModelForQuestionAnswering,
|
AutoModelForQuestionAnswering,
|
||||||
|
AutoModelForSeq2SeqLM,
|
||||||
AutoModelForSequenceClassification,
|
AutoModelForSequenceClassification,
|
||||||
AutoModelForTokenClassification,
|
AutoModelForTokenClassification,
|
||||||
AutoModelWithLMHead,
|
AutoModelWithLMHead,
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
|
PretrainedConfig,
|
||||||
|
PreTrainedTokenizer,
|
||||||
get_linear_schedule_with_warmup,
|
get_linear_schedule_with_warmup,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -31,6 +36,8 @@ MODEL_MODES = {
|
|||||||
"pretraining": AutoModelForPreTraining,
|
"pretraining": AutoModelForPreTraining,
|
||||||
"token-classification": AutoModelForTokenClassification,
|
"token-classification": AutoModelForTokenClassification,
|
||||||
"language-modeling": AutoModelWithLMHead,
|
"language-modeling": AutoModelWithLMHead,
|
||||||
|
"summarization": AutoModelForSeq2SeqLM,
|
||||||
|
"translation": AutoModelForSeq2SeqLM,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -38,33 +45,59 @@ def set_seed(args: argparse.Namespace):
|
|||||||
random.seed(args.seed)
|
random.seed(args.seed)
|
||||||
np.random.seed(args.seed)
|
np.random.seed(args.seed)
|
||||||
torch.manual_seed(args.seed)
|
torch.manual_seed(args.seed)
|
||||||
if args.n_gpu > 0:
|
if args.gpus > 0:
|
||||||
torch.cuda.manual_seed_all(args.seed)
|
torch.cuda.manual_seed_all(args.seed)
|
||||||
|
|
||||||
|
|
||||||
class BaseTransformer(pl.LightningModule):
|
class BaseTransformer(pl.LightningModule):
|
||||||
def __init__(self, hparams: argparse.Namespace, num_labels=None, mode="base", **config_kwargs):
|
def __init__(
|
||||||
|
self,
|
||||||
|
hparams: argparse.Namespace,
|
||||||
|
num_labels=None,
|
||||||
|
mode="base",
|
||||||
|
config=None,
|
||||||
|
tokenizer=None,
|
||||||
|
model=None,
|
||||||
|
**config_kwargs
|
||||||
|
):
|
||||||
"Initialize a model."
|
"Initialize a model."
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hparams = hparams
|
self.hparams = hparams
|
||||||
|
self.step_count = 0
|
||||||
|
self.tfmr_ckpts = {}
|
||||||
|
self.output_dir = Path(self.hparams.output_dir)
|
||||||
cache_dir = self.hparams.cache_dir if self.hparams.cache_dir else None
|
cache_dir = self.hparams.cache_dir if self.hparams.cache_dir else None
|
||||||
|
if config is None:
|
||||||
self.config = AutoConfig.from_pretrained(
|
self.config = AutoConfig.from_pretrained(
|
||||||
self.hparams.config_name if self.hparams.config_name else self.hparams.model_name_or_path,
|
self.hparams.config_name if self.hparams.config_name else self.hparams.model_name_or_path,
|
||||||
**({"num_labels": num_labels} if num_labels is not None else {}),
|
**({"num_labels": num_labels} if num_labels is not None else {}),
|
||||||
cache_dir=cache_dir,
|
cache_dir=cache_dir,
|
||||||
**config_kwargs,
|
**config_kwargs,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
self.config: PretrainedConfig = config
|
||||||
|
if tokenizer is None:
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||||
self.hparams.tokenizer_name if self.hparams.tokenizer_name else self.hparams.model_name_or_path,
|
self.hparams.tokenizer_name if self.hparams.tokenizer_name else self.hparams.model_name_or_path,
|
||||||
cache_dir=cache_dir,
|
cache_dir=cache_dir,
|
||||||
)
|
)
|
||||||
self.model = MODEL_MODES[mode].from_pretrained(
|
else:
|
||||||
|
self.tokenizer: PreTrainedTokenizer = tokenizer
|
||||||
|
if model is None:
|
||||||
|
self.model_type = MODEL_MODES[mode]
|
||||||
|
self.model = self.model_type.from_pretrained(
|
||||||
self.hparams.model_name_or_path,
|
self.hparams.model_name_or_path,
|
||||||
from_tf=bool(".ckpt" in self.hparams.model_name_or_path),
|
from_tf=bool(".ckpt" in self.hparams.model_name_or_path),
|
||||||
config=self.config,
|
config=self.config,
|
||||||
cache_dir=cache_dir,
|
cache_dir=cache_dir,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
self.model_type = None
|
||||||
|
self.model = model
|
||||||
|
|
||||||
|
def load_hf_checkpoint(self, *args, **kwargs):
|
||||||
|
self.model = self.model_type.from_pretrained(*args, **kwargs)
|
||||||
|
|
||||||
def is_logger(self):
|
def is_logger(self):
|
||||||
return self.trainer.proc_rank <= 0
|
return self.trainer.proc_rank <= 0
|
||||||
@ -138,6 +171,15 @@ class BaseTransformer(pl.LightningModule):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@pl.utilities.rank_zero_only
|
||||||
|
def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
|
||||||
|
save_path = self.output_dir.joinpath("best_tfmr")
|
||||||
|
save_path.mkdir(exist_ok=True)
|
||||||
|
self.model.config.save_step = self.step_count
|
||||||
|
self.model.save_pretrained(save_path)
|
||||||
|
self.tokenizer.save_pretrained(save_path)
|
||||||
|
self.tfmr_ckpts[self.step_count] = save_path
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_model_specific_args(parser, root_dir):
|
def add_model_specific_args(parser, root_dir):
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -152,7 +194,7 @@ class BaseTransformer(pl.LightningModule):
|
|||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--tokenizer_name",
|
"--tokenizer_name",
|
||||||
default="",
|
default=None,
|
||||||
type=str,
|
type=str,
|
||||||
help="Pretrained tokenizer name or path if not the same as model_name",
|
help="Pretrained tokenizer name or path if not the same as model_name",
|
||||||
)
|
)
|
||||||
@ -165,7 +207,7 @@ class BaseTransformer(pl.LightningModule):
|
|||||||
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
|
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
|
||||||
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
|
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
|
||||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
||||||
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
|
parser.add_argument("--warmup_steps", default=500, type=int, help="Linear warmup over warmup_steps.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num_train_epochs", default=3, type=int, help="Total number of training epochs to perform."
|
"--num_train_epochs", default=3, type=int, help="Total number of training epochs to perform."
|
||||||
)
|
)
|
||||||
@ -199,7 +241,8 @@ class LoggingCallback(pl.Callback):
|
|||||||
writer.write("{} = {}\n".format(key, str(metrics[key])))
|
writer.write("{} = {}\n".format(key, str(metrics[key])))
|
||||||
|
|
||||||
|
|
||||||
def add_generic_args(parser, root_dir):
|
def add_generic_args(parser, root_dir) -> None:
|
||||||
|
# TODO(SS): allow all pl args? parser = pl.Trainer.add_argparse_args(parser)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--output_dir",
|
"--output_dir",
|
||||||
default=None,
|
default=None,
|
||||||
@ -221,8 +264,8 @@ def add_generic_args(parser, root_dir):
|
|||||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||||
"See details at https://nvidia.github.io/apex/amp.html",
|
"See details at https://nvidia.github.io/apex/amp.html",
|
||||||
)
|
)
|
||||||
|
parser.add_argument("--fast_dev_run", action="store_true")
|
||||||
parser.add_argument("--n_gpu", type=int, default=1)
|
parser.add_argument("--gpus", type=int, default=1)
|
||||||
parser.add_argument("--n_tpu_cores", type=int, default=0)
|
parser.add_argument("--n_tpu_cores", type=int, default=0)
|
||||||
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
||||||
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
||||||
@ -235,28 +278,32 @@ def add_generic_args(parser, root_dir):
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
||||||
|
parser.add_argument("--resume_from_checkpoint", type=str, default=None)
|
||||||
|
parser.add_argument("--val_check_interval", default=1.0, type=float)
|
||||||
|
|
||||||
|
|
||||||
def generic_train(model: BaseTransformer, args: argparse.Namespace):
|
def generic_train(
|
||||||
|
model: BaseTransformer,
|
||||||
|
args: argparse.Namespace,
|
||||||
|
early_stopping_callback=False,
|
||||||
|
logger=True, # can pass WandbLogger() here
|
||||||
|
extra_callbacks=[],
|
||||||
|
checkpoint_callback=None,
|
||||||
|
logging_callback=None,
|
||||||
|
**extra_train_kwargs
|
||||||
|
):
|
||||||
# init model
|
# init model
|
||||||
set_seed(args)
|
set_seed(args)
|
||||||
|
odir = Path(model.hparams.output_dir)
|
||||||
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train:
|
odir.mkdir(exist_ok=True)
|
||||||
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
|
if checkpoint_callback is None:
|
||||||
|
|
||||||
checkpoint_callback = pl.callbacks.ModelCheckpoint(
|
checkpoint_callback = pl.callbacks.ModelCheckpoint(
|
||||||
filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=5
|
filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=1
|
||||||
)
|
)
|
||||||
|
if logging_callback is None:
|
||||||
|
logging_callback = LoggingCallback()
|
||||||
|
|
||||||
train_params = dict(
|
train_params = {}
|
||||||
accumulate_grad_batches=args.gradient_accumulation_steps,
|
|
||||||
gpus=args.n_gpu,
|
|
||||||
max_epochs=args.num_train_epochs,
|
|
||||||
early_stop_callback=False,
|
|
||||||
gradient_clip_val=args.max_grad_norm,
|
|
||||||
checkpoint_callback=checkpoint_callback,
|
|
||||||
callbacks=[LoggingCallback()],
|
|
||||||
)
|
|
||||||
|
|
||||||
if args.fp16:
|
if args.fp16:
|
||||||
train_params["use_amp"] = args.fp16
|
train_params["use_amp"] = args.fp16
|
||||||
@ -269,12 +316,27 @@ def generic_train(model: BaseTransformer, args: argparse.Namespace):
|
|||||||
train_params["num_tpu_cores"] = args.n_tpu_cores
|
train_params["num_tpu_cores"] = args.n_tpu_cores
|
||||||
train_params["gpus"] = 0
|
train_params["gpus"] = 0
|
||||||
|
|
||||||
if args.n_gpu > 1:
|
if args.gpus > 1:
|
||||||
train_params["distributed_backend"] = "ddp"
|
train_params["distributed_backend"] = "ddp"
|
||||||
|
|
||||||
trainer = pl.Trainer(**train_params)
|
trainer = pl.Trainer(
|
||||||
|
logger=logger,
|
||||||
|
accumulate_grad_batches=args.gradient_accumulation_steps,
|
||||||
|
gpus=args.gpus,
|
||||||
|
max_epochs=args.num_train_epochs,
|
||||||
|
early_stop_callback=early_stopping_callback,
|
||||||
|
gradient_clip_val=args.max_grad_norm,
|
||||||
|
checkpoint_callback=checkpoint_callback,
|
||||||
|
callbacks=[logging_callback] + extra_callbacks,
|
||||||
|
fast_dev_run=args.fast_dev_run,
|
||||||
|
val_check_interval=args.val_check_interval,
|
||||||
|
weights_summary=None,
|
||||||
|
resume_from_checkpoint=args.resume_from_checkpoint,
|
||||||
|
**train_params,
|
||||||
|
)
|
||||||
|
|
||||||
if args.do_train:
|
if args.do_train:
|
||||||
trainer.fit(model)
|
trainer.fit(model)
|
||||||
|
trainer.logger.log_hyperparams(args)
|
||||||
|
trainer.logger.save()
|
||||||
return trainer
|
return trainer
|
||||||
|
@ -5,5 +5,6 @@ psutil
|
|||||||
sacrebleu
|
sacrebleu
|
||||||
rouge-score
|
rouge-score
|
||||||
tensorflow_datasets
|
tensorflow_datasets
|
||||||
pytorch-lightning==0.7.3 # April 10, 2020 release
|
pytorch-lightning==0.7.6
|
||||||
matplotlib
|
matplotlib
|
||||||
|
git-python==1.0.3
|
||||||
|
@ -1,47 +1,70 @@
|
|||||||
### Get CNN Data
|
### Data
|
||||||
To be able to reproduce the authors' results on the CNN/Daily Mail dataset you first need to download both CNN and Daily Mail datasets [from Kyunghyun Cho's website](https://cs.nyu.edu/~kcho/DMQA/) (the links next to "Stories") in the same folder. Then uncompress the archives by running:
|
|
||||||
|
|
||||||
|
CNN/DailyMail data
|
||||||
```bash
|
```bash
|
||||||
|
cd examples/summarization
|
||||||
wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/cnn_dm.tgz
|
wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/cnn_dm.tgz
|
||||||
tar -xzvf cnn_dm.tgz
|
tar -xzvf cnn_dm.tgz
|
||||||
|
export CNN_DIR=${PWD}/cnn_dm
|
||||||
```
|
```
|
||||||
|
|
||||||
this should make a directory called cnn_dm/ with files like `test.source`.
|
this should make a directory called cnn_dm/ with files like `test.source`.
|
||||||
To use your own data, copy that files format. Each article to be summarized is on its own line.
|
To use your own data, copy that files format. Each article to be summarized is on its own line.
|
||||||
|
|
||||||
|
XSUM Data:
|
||||||
|
```bash
|
||||||
|
cd examples/summarization
|
||||||
|
wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/xsum.tar.gz
|
||||||
|
tar -xzvf xsum.tar.gz
|
||||||
|
export XSUM_DIR=${PWD}/xsum
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
### Evaluation
|
### Evaluation
|
||||||
|
|
||||||
To create summaries for each article in dataset, run:
|
To create summaries for each article in dataset, run:
|
||||||
```bash
|
```bash
|
||||||
python evaluate_cnn.py <path_to_test.source> test_generations.txt <model-name> --score_path rouge_scores.txt
|
python run_eval.py <path_to_test.source> test_generations.txt <model-name> --score_path rouge_scores.txt
|
||||||
```
|
```
|
||||||
The default batch size, 8, fits in 16GB GPU memory, but may need to be adjusted to fit your system.
|
The default batch size, 4, fits in 16GB GPU memory, but may need to be adjusted to fit your system.
|
||||||
|
|
||||||
|
|
||||||
### Training
|
### Training
|
||||||
Run/modify `finetune_bart.sh` or `finetune_t5.sh`
|
Run/modify `finetune.sh`
|
||||||
|
|
||||||
### Stanford CoreNLP Setup
|
The following command should work on a 16GB GPU:
|
||||||
|
```bash
|
||||||
|
export me=`git config user.name`
|
||||||
|
./finetune.sh \
|
||||||
|
--data_dir $XSUM_DIR \
|
||||||
|
--train_batch_size=1 \
|
||||||
|
--eval_batch_size=1 \
|
||||||
|
--output_dir="$me"_xsum_results \
|
||||||
|
--num_train_epochs 1
|
||||||
```
|
```
|
||||||
ptb_tokenize () {
|
|
||||||
cat $1 | java edu.stanford.nlp.process.PTBTokenizer -ioFileList -preserveLines > $2
|
|
||||||
}
|
|
||||||
|
|
||||||
sudo apt install openjdk-8-jre-headless
|
Tips:
|
||||||
sudo apt-get install ant
|
- 1 epoch at batch size 1 for bart-large takes 24 hours, requires 13GB GPU RAM with fp16 on an NVIDIA-V100.
|
||||||
wget http://nlp.stanford.edu/software/stanford-corenlp-full-2018-10-05.zip
|
- try `bart-base`, `--freeze_encoder` or `--freeze_embeds` for faster training/larger batch size. (3hr/epoch with bs=8, see below)
|
||||||
unzip stanford-corenlp-full-2018-10-05.zip
|
- `fp16_opt_level=O1` (the default works best).
|
||||||
cd stanford-corenlp-full-2018-10-05
|
- If you are finetuning on your own dataset, start from `bart-large-cnn` if you want long summaries and `bart-large-xsum` if you want short summaries.
|
||||||
export CLASSPATH=stanford-corenlp-3.9.2.jar:stanford-corenlp-3.9.2-models.jar
|
(It rarely makes sense to start from `bart-large` unless you are a researching finetuning methods).
|
||||||
```
|
- In addition to the pytorch-lightning .ckpt checkpoint, a transformers checkpoint will be saved.
|
||||||
Then run `ptb_tokenize` on `test.target` and your generated hypotheses.
|
Load it with `BartForConditionalGeneration.from_pretrained(f'{output_dir}/best_tfmr)`.
|
||||||
### Rouge Setup
|
- At the moment, `--do_predict` does not work in a multi-gpu setting. You need to use `evaluate_checkpoint` or the `run_eval.py` code.
|
||||||
Install `files2rouge` following the instructions at [here](https://github.com/pltrdy/files2rouge).
|
- If you want to run experiments on improving the summarization finetuning process, try the XSUM Shared Task (below). It's faster to train than CNNDM because the summaries are shorter.
|
||||||
I also needed to run `sudo apt-get install libxml-parser-perl`
|
|
||||||
|
|
||||||
```python
|
### XSUM Shared Task
|
||||||
from files2rouge import files2rouge
|
Compare XSUM results with others by using `--logger wandb_shared`. This requires `wandb` registration.
|
||||||
from files2rouge import settings
|
Here is an example command
|
||||||
files2rouge.run(<path_to_tokenized_hypo>,
|
```bash
|
||||||
<path_to_tokenized_target>,
|
export me=`git config user.name`
|
||||||
saveto='rouge_output.txt')
|
./finetune.sh \
|
||||||
|
--data_dir $XSUM_DIR \
|
||||||
|
--output_dir "$me"_xsum_frozen_embs \
|
||||||
|
--logger wandb_shared \
|
||||||
|
--train_batch_size 16 --eval_batch_size 16 --freeze_embeds --freeze_encoder \
|
||||||
|
--num_train_epochs 6
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Results can be viewed [here](https://app.wandb.ai/sshleifer/hf_summarization/table?workspace=user-)
|
||||||
|
85
examples/summarization/callbacks.py
Normal file
85
examples/summarization/callbacks.py
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
import torch
|
||||||
|
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||||
|
from pytorch_lightning.utilities import rank_zero_only
|
||||||
|
|
||||||
|
|
||||||
|
def count_trainable_parameters(model):
|
||||||
|
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
|
||||||
|
params = sum([np.prod(p.size()) for p in model_parameters])
|
||||||
|
return params
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class Seq2SeqLoggingCallback(pl.Callback):
|
||||||
|
def _write_logs(
|
||||||
|
self, trainer: pl.Trainer, pl_module: pl.LightningModule, type_path: str, save_generations=True
|
||||||
|
) -> None:
|
||||||
|
logger.info(f"***** {type_path} results at step {trainer.global_step:05d} *****")
|
||||||
|
if not pl_module.is_logger():
|
||||||
|
return
|
||||||
|
metrics = trainer.callback_metrics
|
||||||
|
trainer.logger.log_metrics({k: v for k, v in metrics.items() if k not in ["log", "progress_bar", "preds"]})
|
||||||
|
# Log results
|
||||||
|
od = Path(pl_module.hparams.output_dir)
|
||||||
|
if type_path == "test":
|
||||||
|
results_file = od / "test_results.txt"
|
||||||
|
generations_file = od / "test_generations.txt"
|
||||||
|
else:
|
||||||
|
results_file = od / f"{type_path}_results_{trainer.global_step:05d}.txt"
|
||||||
|
generations_file = od / f"{type_path}_generations_{trainer.global_step:05d}.txt"
|
||||||
|
|
||||||
|
with open(results_file, "a+") as writer:
|
||||||
|
for key in sorted(metrics):
|
||||||
|
if key in ["log", "progress_bar", "preds"]:
|
||||||
|
continue
|
||||||
|
val = metrics[key]
|
||||||
|
if isinstance(val, torch.Tensor):
|
||||||
|
val = val.item()
|
||||||
|
msg = f"{key}: {val:.6f}\n"
|
||||||
|
writer.write(msg)
|
||||||
|
|
||||||
|
if not save_generations:
|
||||||
|
return
|
||||||
|
|
||||||
|
if "preds" in metrics:
|
||||||
|
content = "\n".join(metrics["preds"])
|
||||||
|
generations_file.open("w+").write(content)
|
||||||
|
|
||||||
|
@rank_zero_only
|
||||||
|
def on_train_start(self, trainer, pl_module):
|
||||||
|
try:
|
||||||
|
npars = pl_module.model.model.num_parameters()
|
||||||
|
except AttributeError:
|
||||||
|
npars = pl_module.model.num_parameters()
|
||||||
|
|
||||||
|
n_trainable_pars = count_trainable_parameters(pl_module)
|
||||||
|
# mp stands for million parameters
|
||||||
|
trainer.logger.log_metrics({"n_params": npars, "mp": npars / 1e6, "grad_mp": n_trainable_pars / 1e6})
|
||||||
|
|
||||||
|
@rank_zero_only
|
||||||
|
def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
|
||||||
|
return self._write_logs(trainer, pl_module, "val")
|
||||||
|
|
||||||
|
@rank_zero_only
|
||||||
|
def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
|
||||||
|
return self._write_logs(trainer, pl_module, "test")
|
||||||
|
|
||||||
|
|
||||||
|
def get_rouge2_checkpoint_callback(output_dir):
|
||||||
|
"""Saves the best model by validation ROUGE2 score."""
|
||||||
|
checkpoint_callback = ModelCheckpoint(
|
||||||
|
filepath=os.path.join(output_dir, "{val_avg_rouge2:.4f}-{step_count}"),
|
||||||
|
monitor="val_rouge",
|
||||||
|
mode="max",
|
||||||
|
save_top_k=1,
|
||||||
|
period=0, # maybe save a checkpoint every time val is run, not just end of epoch.
|
||||||
|
)
|
||||||
|
return checkpoint_callback
|
448
examples/summarization/distillation.py
Normal file
448
examples/summarization/distillation.py
Normal file
@ -0,0 +1,448 @@
|
|||||||
|
import argparse
|
||||||
|
import gc
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
from lightning_base import generic_train
|
||||||
|
from transformers import AdamW, BartConfig, BartForConditionalGeneration, T5Config, T5ForConditionalGeneration
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
from .finetune import SummarizationModule
|
||||||
|
from .initialization_utils import init_student, copy_layers
|
||||||
|
from .utils import (
|
||||||
|
use_task_specific_params,
|
||||||
|
SummarizationDataset,
|
||||||
|
pickle_load,
|
||||||
|
freeze_params,
|
||||||
|
assert_all_frozen,
|
||||||
|
any_requires_grad,
|
||||||
|
)
|
||||||
|
from .finetune import main as ft_main
|
||||||
|
except ImportError:
|
||||||
|
from finetune import SummarizationModule
|
||||||
|
from finetune import main as ft_main
|
||||||
|
from initialization_utils import init_student, copy_layers
|
||||||
|
from utils import (
|
||||||
|
use_task_specific_params,
|
||||||
|
SummarizationDataset,
|
||||||
|
pickle_load,
|
||||||
|
freeze_params,
|
||||||
|
assert_all_frozen,
|
||||||
|
any_requires_grad,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SummarizationDistiller(SummarizationModule):
|
||||||
|
loss_names = ["loss", "ce_loss", "mlm_loss", "enc_mse_loss", "hid_loss_enc", "hid_loss_dec"]
|
||||||
|
|
||||||
|
def __init__(self, hparams):
|
||||||
|
assert Path(hparams.data_dir).exists()
|
||||||
|
|
||||||
|
d_layers_to_copy, student, student_cfg, teacher = self.pre_init(hparams)
|
||||||
|
|
||||||
|
super().__init__(hparams, model=student, config=student_cfg)
|
||||||
|
self.teacher = teacher
|
||||||
|
use_task_specific_params(self.teacher, "summarization")
|
||||||
|
freeze_params(self.teacher)
|
||||||
|
self.sanity_check_gradients()
|
||||||
|
self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean")
|
||||||
|
self.temperature = 2.0
|
||||||
|
self.alpha_mlm = hparams.alpha_mlm
|
||||||
|
self.alpha_ce = hparams.alpha_ce
|
||||||
|
self.alpha_hid = hparams.alpha_hid
|
||||||
|
# self.alpha_cos = hparams.alpha_cos
|
||||||
|
self.alpha_encoder_loss = self.hparams.alpha_encoder_loss
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
def sanity_check_gradients(self):
|
||||||
|
assert_all_frozen(self.teacher)
|
||||||
|
assert_all_frozen(self.model.model.decoder.embed_tokens)
|
||||||
|
assert_all_frozen(self.model.model.encoder.embed_tokens)
|
||||||
|
if self.different_encoder:
|
||||||
|
assert any_requires_grad(self.model.model.encoder)
|
||||||
|
else:
|
||||||
|
freeze_params(self.model.model.encoder)
|
||||||
|
del self.teacher.model.encoder
|
||||||
|
|
||||||
|
def pre_init(self, hparams):
|
||||||
|
# Dump empty student model at a path, then call from_pretrained on it
|
||||||
|
teacher = BartForConditionalGeneration.from_pretrained(hparams.teacher).eval()
|
||||||
|
student_updates = {
|
||||||
|
"decoder_layers": hparams.student_decoder_layers,
|
||||||
|
"encoder_layers": hparams.student_encoder_layers,
|
||||||
|
}
|
||||||
|
d_layers_to_copy = get_layers_to_copy(student_updates["decoder_layers"], teacher.config.decoder_layers)
|
||||||
|
e_layers_to_copy: List = get_layers_to_copy(student_updates["encoder_layers"], teacher.config.encoder_layers)
|
||||||
|
hparams.d_layer_to_copy = d_layers_to_copy
|
||||||
|
hparams.e_layer_to_copy = e_layers_to_copy
|
||||||
|
kw = teacher.config.to_diff_dict()
|
||||||
|
kw.update(student_updates)
|
||||||
|
# Copy weights
|
||||||
|
student_cfg = BartConfig(**kw)
|
||||||
|
student = BartForConditionalGeneration(student_cfg)
|
||||||
|
student, _ = init_student(student, teacher)
|
||||||
|
self.copy_to_student(d_layers_to_copy, e_layers_to_copy, hparams, student, teacher)
|
||||||
|
Path(hparams.output_dir).mkdir(exist_ok=True)
|
||||||
|
return d_layers_to_copy, student, student_cfg, teacher
|
||||||
|
|
||||||
|
def copy_to_student(self, d_layers_to_copy, e_layers_to_copy, hparams, student, teacher):
|
||||||
|
if teacher.config.model_type == "t5":
|
||||||
|
return self.copy_t5_to_student(d_layers_to_copy, e_layers_to_copy, hparams, student, teacher)
|
||||||
|
self.different_encoder: bool = hparams.student_encoder_layers != teacher.config.encoder_layers
|
||||||
|
self.different_decoder = hparams.student_decoder_layers != teacher.config.decoder_layers
|
||||||
|
if self.different_decoder:
|
||||||
|
copy_layers(teacher.model.decoder.layers, student.model.decoder.layers, d_layers_to_copy)
|
||||||
|
if self.different_encoder:
|
||||||
|
copy_layers(teacher.model.encoder.layers, student.model.encoder.layers, e_layers_to_copy)
|
||||||
|
|
||||||
|
def copy_t5_to_student(self, d_layers_to_copy, e_layers_to_copy, hparams, student, teacher):
|
||||||
|
self.different_encoder: bool = hparams.student_encoder_layers != teacher.config.num_layers
|
||||||
|
self.different_decoder = hparams.student_decoder_layers != teacher.config.num_layers
|
||||||
|
if self.different_decoder:
|
||||||
|
copy_layers(teacher.decoder.block, student.decoder.block, d_layers_to_copy)
|
||||||
|
if self.different_encoder:
|
||||||
|
copy_layers(teacher.encoder.block, student.encoder.block, e_layers_to_copy)
|
||||||
|
|
||||||
|
def get_dataset(self, type_path) -> SummarizationDataset:
|
||||||
|
n_obs = self.n_obs[type_path]
|
||||||
|
dataset = SummarizationDataset(self.tokenizer, type_path=type_path, n_obs=n_obs, **self.dataset_kwargs)
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
def calc_mse_loss(self, teacher_outputs: torch.Tensor, student_outputs: torch.Tensor, mask) -> torch.FloatTensor:
|
||||||
|
if mask is not None:
|
||||||
|
# mask has False at padding_idx
|
||||||
|
sel_mask = mask[:, :, None].expand_as(student_outputs).bool()
|
||||||
|
s_logits_slct = torch.masked_select(student_outputs, sel_mask)
|
||||||
|
t_logits_slct = torch.masked_select(teacher_outputs, sel_mask)
|
||||||
|
else:
|
||||||
|
t_logits_slct = teacher_outputs
|
||||||
|
s_logits_slct = student_outputs
|
||||||
|
return F.mse_loss(s_logits_slct, t_logits_slct)
|
||||||
|
|
||||||
|
def calc_ce_loss(self, mask, s_logits, t_logits):
|
||||||
|
if mask is not None:
|
||||||
|
# mask has False at padding_idx
|
||||||
|
sel_mask = mask[:, :, None].expand_as(s_logits)
|
||||||
|
s_logits_slct = torch.masked_select(
|
||||||
|
s_logits, sel_mask
|
||||||
|
) # (bs * seq_length * voc_size) modulo the 1s in mask
|
||||||
|
t_logits_slct = torch.masked_select(
|
||||||
|
t_logits, sel_mask
|
||||||
|
) # (bs * seq_length * voc_size) modulo the 1s in mask
|
||||||
|
else:
|
||||||
|
t_logits_slct = t_logits
|
||||||
|
s_logits_slct = s_logits # (bs * seq_length * voc_size) modulo the 1s in mask
|
||||||
|
s_logits_slct = s_logits_slct.view(-1, s_logits.size(-1)) # (bs * seq_length, voc_size) modulo the 1s in mask
|
||||||
|
t_logits_slct = t_logits_slct.view(-1, s_logits.size(-1)) # (bs * seq_length, voc_size) modulo the 1s in mask
|
||||||
|
assert t_logits_slct.size() == s_logits_slct.size()
|
||||||
|
loss_ce = (
|
||||||
|
self.ce_loss_fct(
|
||||||
|
F.log_softmax(s_logits_slct / self.temperature, dim=-1),
|
||||||
|
F.softmax(t_logits_slct / self.temperature, dim=-1),
|
||||||
|
)
|
||||||
|
* (self.temperature) ** 2
|
||||||
|
)
|
||||||
|
return loss_ce, s_logits_slct, t_logits_slct
|
||||||
|
|
||||||
|
def configure_optimizers(self):
|
||||||
|
"Prepare optimizer and schedule (linear warmup and decay)"
|
||||||
|
|
||||||
|
model = self.model
|
||||||
|
no_decay = ["bias", "LayerNorm.weight"]
|
||||||
|
optimizer_grouped_parameters = [
|
||||||
|
{
|
||||||
|
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
|
||||||
|
"weight_decay": self.hparams.weight_decay,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
|
||||||
|
"weight_decay": 0.0,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)
|
||||||
|
self.opt = optimizer
|
||||||
|
return [optimizer]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def add_model_specific_args(parser, root_dir):
|
||||||
|
SummarizationModule.add_model_specific_args(parser, root_dir)
|
||||||
|
parser.add_argument("--teacher", default="facebook/bart-large-cnn", type=str)
|
||||||
|
parser.add_argument("--alpha_ce", default=0.8, type=float)
|
||||||
|
parser.add_argument("--alpha_mlm", default=0.2, type=float)
|
||||||
|
# parser.add_argument("--alpha_cos", default=0.0, type=float)
|
||||||
|
parser.add_argument("--alpha_encoder_loss", default=0.0, type=float)
|
||||||
|
parser.add_argument("--alpha_hid", default=0.0, type=float, required=False)
|
||||||
|
parser.add_argument(
|
||||||
|
"--student_decoder_layers", default=12, type=int, required=False,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--student_encoder_layers", default=12, type=int, required=False,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--no_teacher", action="store_true", default=False,
|
||||||
|
)
|
||||||
|
parser.add_argument( # TODO: remove
|
||||||
|
"--enc_only", action="store_true", default=False,
|
||||||
|
)
|
||||||
|
return parser
|
||||||
|
|
||||||
|
def _step(self, batch):
|
||||||
|
# assert is_frozen(self.teacher)
|
||||||
|
pad_token_id = self.tokenizer.pad_token_id
|
||||||
|
input_ids, src_mask, y = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]
|
||||||
|
decoder_input_ids = y[:, :-1].contiguous()
|
||||||
|
labels = y[:, 1:].clone()
|
||||||
|
labels[y[:, 1:] == pad_token_id] = -100
|
||||||
|
# noinspection PyCallingNonCallable
|
||||||
|
sloss, slogits, dec_hidden, enc_outputs, enc_hidden_state = self(
|
||||||
|
input_ids,
|
||||||
|
attention_mask=src_mask,
|
||||||
|
decoder_input_ids=decoder_input_ids,
|
||||||
|
labels=labels,
|
||||||
|
output_hidden_states=True,
|
||||||
|
output_attentions=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def zero_tensor():
|
||||||
|
return torch.tensor(0.0).type_as(sloss)
|
||||||
|
|
||||||
|
loss_encoder, hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor(), zero_tensor()
|
||||||
|
if self.different_encoder:
|
||||||
|
with torch.no_grad():
|
||||||
|
teacher_enc_outputs, teacher_enc_hid, _ = self.teacher.model.encoder(
|
||||||
|
input_ids, attention_mask=src_mask, output_hidden_states=True
|
||||||
|
)
|
||||||
|
if self.hparams.alpha_encoder_loss > 0:
|
||||||
|
loss_encoder = self.calc_mse_loss(enc_outputs, teacher_enc_outputs, src_mask)
|
||||||
|
|
||||||
|
hid_loss_enc = self.calc_hidden_loss(
|
||||||
|
src_mask, enc_hidden_state, teacher_enc_hid, self.hparams.e_layer_to_copy
|
||||||
|
)
|
||||||
|
|
||||||
|
teacher_enc_outputs = (enc_outputs,)
|
||||||
|
assert isinstance(teacher_enc_outputs, tuple), type(teacher_enc_outputs)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
tloss, tlogits, tdec_hidden, _ = self.teacher(
|
||||||
|
input_ids,
|
||||||
|
attention_mask=src_mask,
|
||||||
|
encoder_outputs=teacher_enc_outputs,
|
||||||
|
decoder_input_ids=decoder_input_ids,
|
||||||
|
lm_labels=labels,
|
||||||
|
output_hidden_states=True,
|
||||||
|
)
|
||||||
|
dec_mask = decoder_input_ids.ne(pad_token_id)
|
||||||
|
loss_ce, s_logits_slct, t_logits_slct = self.calc_ce_loss(dec_mask, slogits, tlogits)
|
||||||
|
if self.alpha_hid > 0:
|
||||||
|
hid_loss_dec = self.calc_hidden_loss(dec_mask, dec_hidden, tdec_hidden, self.hparams.d_layer_to_copy)
|
||||||
|
|
||||||
|
blended_loss = (
|
||||||
|
self.alpha_ce * loss_ce
|
||||||
|
+ self.alpha_mlm * sloss
|
||||||
|
+ self.hparams.alpha_encoder_loss * loss_encoder
|
||||||
|
+ self.hparams.alpha_hid * (hid_loss_enc + hid_loss_dec)
|
||||||
|
)
|
||||||
|
return blended_loss, loss_ce, sloss, loss_encoder, hid_loss_enc, hid_loss_dec
|
||||||
|
|
||||||
|
def calc_hidden_loss(self, attention_mask, hidden_states, hidden_states_T, matches):
|
||||||
|
assert not isinstance(
|
||||||
|
hidden_states, torch.Tensor
|
||||||
|
), f"expected list or tuple for hidden_states, got tensor of shape {hidden_states.shape}"
|
||||||
|
assert not isinstance(
|
||||||
|
hidden_states_T, torch.Tensor
|
||||||
|
), f"expected list or tuple for hidden_states_T, got tensor of shape {hidden_states_T.shape}"
|
||||||
|
mask = attention_mask.to(hidden_states[0])
|
||||||
|
valid_count = mask.sum() * hidden_states[0].size(-1)
|
||||||
|
hidden_losses = [
|
||||||
|
(F.mse_loss(hidden_states[i], hidden_states_T[j], reduction="none") * mask.unsqueeze(-1)).sum()
|
||||||
|
/ valid_count
|
||||||
|
for i, j in enumerate(matches)
|
||||||
|
]
|
||||||
|
return sum(hidden_losses)
|
||||||
|
|
||||||
|
|
||||||
|
class T5SummarizationDistiller(SummarizationDistiller):
|
||||||
|
def pre_init(self, hparams):
|
||||||
|
teacher = T5ForConditionalGeneration.from_pretrained(hparams.teacher)
|
||||||
|
n_layer = hparams.student_decoder_layers
|
||||||
|
assert n_layer == hparams.student_encoder_layers # TODO(SS): relax this
|
||||||
|
d_layers_to_copy = get_layers_to_copy(n_layer, len(teacher.decoder.block))
|
||||||
|
e_layers_to_copy: List = get_layers_to_copy(n_layer, len(teacher.encoder.block))
|
||||||
|
student_updates = {"num_layers": n_layer}
|
||||||
|
hparams.d_layer_to_copy = d_layers_to_copy
|
||||||
|
hparams.e_layer_to_copy = e_layers_to_copy
|
||||||
|
kw = teacher.config.to_diff_dict()
|
||||||
|
|
||||||
|
kw.update(student_updates)
|
||||||
|
# Copy weights
|
||||||
|
student_cfg = T5Config(**kw)
|
||||||
|
student = T5ForConditionalGeneration(student_cfg)
|
||||||
|
student, _ = init_student(student, teacher)
|
||||||
|
self.copy_to_student(d_layers_to_copy, e_layers_to_copy, hparams, student, teacher)
|
||||||
|
Path(hparams.output_dir).mkdir(exist_ok=True)
|
||||||
|
task_specific_params = student.config.task_specific_params
|
||||||
|
if task_specific_params is not None:
|
||||||
|
student.config.update(task_specific_params.get("summarization", {}))
|
||||||
|
return d_layers_to_copy, student, student_cfg, teacher
|
||||||
|
|
||||||
|
def freeze_embeds(self):
|
||||||
|
freeze_params(self.model.shared)
|
||||||
|
for d in [self.model.encoder, self.model.decoder]:
|
||||||
|
freeze_params(d.embed_tokens)
|
||||||
|
|
||||||
|
def sanity_check_gradients(self):
|
||||||
|
"""T5"""
|
||||||
|
assert_all_frozen(self.teacher)
|
||||||
|
assert_all_frozen(self.model.decoder.embed_tokens)
|
||||||
|
assert_all_frozen(self.model.encoder.embed_tokens)
|
||||||
|
if self.different_encoder:
|
||||||
|
assert any_requires_grad(self.model.encoder)
|
||||||
|
else:
|
||||||
|
freeze_params(self.model.encoder)
|
||||||
|
del self.teacher.model.encoder
|
||||||
|
if self.different_decoder:
|
||||||
|
assert any_requires_grad(self.model.decoder)
|
||||||
|
else:
|
||||||
|
freeze_params(self.model.decoder) # TODO(SS): very suspicious
|
||||||
|
|
||||||
|
def _step(self, batch):
|
||||||
|
pad_token_id = self.tokenizer.pad_token_id
|
||||||
|
source_ids, source_mask, y = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]
|
||||||
|
decoder_input_ids = y[:, :-1].contiguous()
|
||||||
|
labels = y[:, 1:].clone()
|
||||||
|
labels[y[:, 1:] == pad_token_id] = -100
|
||||||
|
# noinspection PyCallingNonCallable
|
||||||
|
dec_mask = decoder_input_ids.ne(pad_token_id)
|
||||||
|
|
||||||
|
sloss, slogits, dec_hidden, enc_outputs, enc_hidden_state = self(
|
||||||
|
source_ids,
|
||||||
|
attention_mask=source_mask,
|
||||||
|
decoder_input_ids=decoder_input_ids,
|
||||||
|
labels=labels,
|
||||||
|
output_hidden_states=True,
|
||||||
|
output_attentions=False,
|
||||||
|
use_cache=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def zero_tensor():
|
||||||
|
return torch.tensor(0.0).type_as(sloss)
|
||||||
|
|
||||||
|
loss_encoder, hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor(), zero_tensor()
|
||||||
|
if self.different_encoder:
|
||||||
|
with torch.no_grad():
|
||||||
|
teacher_enc_outputs, teacher_enc_hid = self.teacher.encoder(
|
||||||
|
source_ids, attention_mask=source_mask, output_hidden_states=True, use_cache=False,
|
||||||
|
)
|
||||||
|
if self.hparams.alpha_encoder_loss > 0:
|
||||||
|
loss_encoder = self.calc_mse_loss(enc_outputs, teacher_enc_outputs, source_mask)
|
||||||
|
|
||||||
|
hid_loss_enc = self.calc_hidden_loss(
|
||||||
|
source_mask, enc_hidden_state, teacher_enc_hid, self.hparams.e_layer_to_copy
|
||||||
|
)
|
||||||
|
|
||||||
|
teacher_enc_outputs = (enc_outputs,)
|
||||||
|
assert isinstance(teacher_enc_outputs, tuple), type(teacher_enc_outputs)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
tloss, tlogits, tdec_hidden, _ = self.teacher(
|
||||||
|
source_ids,
|
||||||
|
attention_mask=source_mask,
|
||||||
|
encoder_outputs=teacher_enc_outputs,
|
||||||
|
decoder_input_ids=decoder_input_ids,
|
||||||
|
lm_labels=labels,
|
||||||
|
output_hidden_states=True,
|
||||||
|
use_cache=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
loss_ce, s_logits_slct, t_logits_slct = self.calc_ce_loss(dec_mask, slogits, tlogits)
|
||||||
|
if self.alpha_hid > 0:
|
||||||
|
hid_loss_dec = self.calc_hidden_loss(dec_mask, dec_hidden, tdec_hidden, self.hparams.d_layer_to_copy)
|
||||||
|
|
||||||
|
blended_loss = (
|
||||||
|
self.alpha_ce * loss_ce
|
||||||
|
+ self.alpha_mlm * sloss
|
||||||
|
+ self.hparams.alpha_encoder_loss * loss_encoder
|
||||||
|
+ self.hparams.alpha_hid * (hid_loss_enc + hid_loss_dec)
|
||||||
|
)
|
||||||
|
return blended_loss, loss_ce, sloss, loss_encoder, hid_loss_enc, hid_loss_dec
|
||||||
|
|
||||||
|
|
||||||
|
def create_module(args):
|
||||||
|
t5 = "t5" in args.model_name_or_path
|
||||||
|
if args.no_teacher:
|
||||||
|
assert not args.enc_only
|
||||||
|
module_cls = SummarizationModule
|
||||||
|
elif t5:
|
||||||
|
module_cls = T5SummarizationDistiller
|
||||||
|
elif args.enc_only:
|
||||||
|
raise ValueError("Deleted that")
|
||||||
|
else:
|
||||||
|
module_cls = SummarizationDistiller
|
||||||
|
args.setup_cls: str = module_cls.__name__
|
||||||
|
model = module_cls(args)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_checkpoint(ckpt_path: Path, dest_dir=None):
|
||||||
|
exp_dir = ckpt_path.parent
|
||||||
|
if dest_dir is None:
|
||||||
|
dest_dir = exp_dir
|
||||||
|
clash = list(dest_dir.glob("test_generations*"))
|
||||||
|
if clash:
|
||||||
|
print(f"SKIPPING to avoid overwriting {clash}")
|
||||||
|
ckpt = torch.load(ckpt_path, map_location="cpu")
|
||||||
|
if "hparams" in ckpt:
|
||||||
|
args = argparse.Namespace(**ckpt["hparams"])
|
||||||
|
else:
|
||||||
|
args = argparse.Namespace(**pickle_load(exp_dir / "hparams.pkl"))
|
||||||
|
args.resume_from_checkpoint = str(ckpt_path)
|
||||||
|
args.do_train = False
|
||||||
|
args.output_dir = str(dest_dir)
|
||||||
|
args.n_gpu = 1
|
||||||
|
args.eval_batch_size = 16
|
||||||
|
Path(args.output_dir).mkdir(exist_ok=True)
|
||||||
|
model = create_module(args)
|
||||||
|
trainer: pl.Trainer = generic_train(model, args, early_stopping_callback=False)
|
||||||
|
trainer.test(model)
|
||||||
|
|
||||||
|
|
||||||
|
def get_layers_to_copy(n_to_get, tot):
|
||||||
|
all_layers = list(range(tot))
|
||||||
|
if tot == 12: # Alternating for special cases
|
||||||
|
layers_to_copy = { # maps # layers in student -> which teacher layers to copy
|
||||||
|
6: [0, 2, 4, 7, 9, 11],
|
||||||
|
1: [11],
|
||||||
|
3: [0, 6, 11],
|
||||||
|
2: [0, 11],
|
||||||
|
4: [0, 4, 8, 11],
|
||||||
|
9: [0, 1, 2, 4, 5, 7, 9, 10, 11],
|
||||||
|
12: all_layers,
|
||||||
|
}
|
||||||
|
return layers_to_copy[n_to_get]
|
||||||
|
else:
|
||||||
|
return all_layers[:n_to_get]
|
||||||
|
|
||||||
|
|
||||||
|
def distill_main(args):
|
||||||
|
Path(args.output_dir).mkdir(exist_ok=True)
|
||||||
|
if len(os.listdir(args.output_dir)) > 3 and args.do_train:
|
||||||
|
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
|
||||||
|
|
||||||
|
model = create_module(args)
|
||||||
|
return ft_main(args, model=model)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser = SummarizationDistiller.add_model_specific_args(parser, os.getcwd())
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
distill_main(args)
|
@ -1,100 +0,0 @@
|
|||||||
import argparse
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from rouge_score import rouge_scorer, scoring
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
from transformers import AutoModelWithLMHead, AutoTokenizer
|
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
||||||
|
|
||||||
|
|
||||||
def chunks(lst, n):
|
|
||||||
"""Yield successive n-sized chunks from lst."""
|
|
||||||
for i in range(0, len(lst), n):
|
|
||||||
yield lst[i : i + n]
|
|
||||||
|
|
||||||
|
|
||||||
def generate_summaries(
|
|
||||||
examples: list, out_file: str, model_name: str, batch_size: int = 8, device: str = DEFAULT_DEVICE
|
|
||||||
):
|
|
||||||
fout = Path(out_file).open("w", encoding="utf-8")
|
|
||||||
model = AutoModelWithLMHead.from_pretrained(model_name).to(device)
|
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
|
||||||
|
|
||||||
# update config with summarization specific params
|
|
||||||
task_specific_params = model.config.task_specific_params
|
|
||||||
if task_specific_params is not None:
|
|
||||||
model.config.update(task_specific_params.get("summarization", {}))
|
|
||||||
|
|
||||||
for batch in tqdm(list(chunks(examples, batch_size))):
|
|
||||||
if "t5" in model_name:
|
|
||||||
batch = [model.config.prefix + text for text in batch]
|
|
||||||
dct = tokenizer.batch_encode_plus(batch, max_length=1024, return_tensors="pt", pad_to_max_length=True).to(
|
|
||||||
device
|
|
||||||
)
|
|
||||||
summaries = model.generate(**dct)
|
|
||||||
|
|
||||||
dec = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
|
||||||
for hypothesis in dec:
|
|
||||||
fout.write(hypothesis + "\n")
|
|
||||||
fout.flush()
|
|
||||||
|
|
||||||
|
|
||||||
def calculate_rouge(output_lns, reference_lns, score_path):
|
|
||||||
score_file = Path(score_path).open("w")
|
|
||||||
scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
|
|
||||||
aggregator = scoring.BootstrapAggregator()
|
|
||||||
|
|
||||||
for reference_ln, output_ln in zip(reference_lns, output_lns):
|
|
||||||
scores = scorer.score(reference_ln, output_ln)
|
|
||||||
aggregator.add_scores(scores)
|
|
||||||
|
|
||||||
result = aggregator.aggregate()
|
|
||||||
score_file.write(
|
|
||||||
"ROUGE_1: \n{} \n\n ROUGE_2: \n{} \n\n ROUGE_L: \n{} \n\n".format(
|
|
||||||
result["rouge1"], result["rouge2"], result["rougeL"]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def run_generate():
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument(
|
|
||||||
"input_path", type=str, help="like cnn_dm/test.source or cnn_dm/test_articles_input.txt",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"output_path", type=str, help="where to save summaries",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"model_name",
|
|
||||||
type=str,
|
|
||||||
default="facebook/bart-large-cnn",
|
|
||||||
help="like bart-large-cnn,'t5-small', 't5-base', 't5-large', 't5-3b', 't5-11b",
|
|
||||||
)
|
|
||||||
parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test_reference_summaries.txt")
|
|
||||||
parser.add_argument(
|
|
||||||
"--score_path", type=str, required=False, help="where to save the rouge score",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--device", type=str, required=False, default=DEFAULT_DEVICE, help="cuda, cuda:1, cpu etc.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--bs", type=int, default=8, required=False, help="batch size: how many to summarize at a time",
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
examples = [" " + x.rstrip() if "t5" in args.model_name else x.rstrip() for x in open(args.input_path).readlines()]
|
|
||||||
|
|
||||||
generate_summaries(examples, args.output_path, args.model_name, batch_size=args.bs, device=args.device)
|
|
||||||
if args.score_path is not None:
|
|
||||||
output_lns = [x.rstrip() for x in open(args.output_path).readlines()]
|
|
||||||
reference_lns = [x.rstrip() for x in open(args.reference_path).readlines()]
|
|
||||||
|
|
||||||
calculate_rouge(output_lns, reference_lns, args.score_path)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
run_generate()
|
|
@ -3,91 +3,169 @@ import glob
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from lightning_base import BaseTransformer, add_generic_args, generic_train, get_linear_schedule_with_warmup
|
from lightning_base import BaseTransformer, add_generic_args, generic_train
|
||||||
|
from transformers import get_linear_schedule_with_warmup
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from .utils import SummarizationDataset
|
from .utils import (
|
||||||
|
use_task_specific_params,
|
||||||
|
SummarizationDataset,
|
||||||
|
lmap,
|
||||||
|
flatten_list,
|
||||||
|
pickle_save,
|
||||||
|
save_git_info,
|
||||||
|
freeze_params,
|
||||||
|
calculate_rouge,
|
||||||
|
get_git_info,
|
||||||
|
ROUGE_KEYS,
|
||||||
|
)
|
||||||
|
from .callbacks import Seq2SeqLoggingCallback, get_rouge2_checkpoint_callback
|
||||||
except ImportError:
|
except ImportError:
|
||||||
from utils import SummarizationDataset
|
from utils import (
|
||||||
|
use_task_specific_params,
|
||||||
|
SummarizationDataset,
|
||||||
|
lmap,
|
||||||
|
flatten_list,
|
||||||
|
pickle_save,
|
||||||
|
save_git_info,
|
||||||
|
freeze_params,
|
||||||
|
calculate_rouge,
|
||||||
|
get_git_info,
|
||||||
|
ROUGE_KEYS,
|
||||||
|
)
|
||||||
|
from callbacks import Seq2SeqLoggingCallback, get_rouge2_checkpoint_callback
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class SummarizationTrainer(BaseTransformer):
|
class SummarizationModule(BaseTransformer):
|
||||||
|
mode = "summarization"
|
||||||
|
loss_names = ["loss"]
|
||||||
|
|
||||||
mode = "language-modeling"
|
def __init__(self, hparams, **kwargs):
|
||||||
|
super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs)
|
||||||
|
use_task_specific_params(self.model, "summarization")
|
||||||
|
save_git_info(self.hparams.output_dir)
|
||||||
|
self.metrics_save_path = Path(self.output_dir) / "metrics.pkl"
|
||||||
|
self.hparams_save_path = Path(self.output_dir) / "hparams.pkl"
|
||||||
|
self.step_count = 0
|
||||||
|
self.metrics = {"train": [], "val": [], "test": []}
|
||||||
|
|
||||||
def __init__(self, hparams):
|
|
||||||
super().__init__(hparams, num_labels=None, mode=self.mode)
|
|
||||||
self.dataset_kwargs: dict = dict(
|
self.dataset_kwargs: dict = dict(
|
||||||
data_dir=self.hparams.data_dir,
|
data_dir=self.hparams.data_dir,
|
||||||
max_source_length=self.hparams.max_source_length,
|
max_source_length=self.hparams.max_source_length,
|
||||||
max_target_length=self.hparams.max_target_length,
|
prefix=self.model.config.prefix or "",
|
||||||
)
|
)
|
||||||
|
n_observations_per_split = {
|
||||||
|
"train": self.hparams.n_train,
|
||||||
|
"val": self.hparams.n_val,
|
||||||
|
"test": self.hparams.n_test,
|
||||||
|
}
|
||||||
|
self.n_obs = {k: v if v >= 0 else None for k, v in n_observations_per_split.items()}
|
||||||
|
|
||||||
def forward(self, input_ids, attention_mask=None, decoder_input_ids=None, lm_labels=None):
|
self.target_lens = {
|
||||||
return self.model(
|
"train": self.hparams.max_target_length,
|
||||||
input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, lm_labels=lm_labels,
|
"val": self.hparams.val_max_target_length,
|
||||||
|
"test": self.hparams.test_max_target_length,
|
||||||
|
}
|
||||||
|
assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
|
||||||
|
assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"
|
||||||
|
|
||||||
|
if self.hparams.freeze_embeds:
|
||||||
|
self.freeze_embeds()
|
||||||
|
if self.hparams.freeze_encoder:
|
||||||
|
freeze_params(self.model.model.encoder) # TODO: this will break for t5
|
||||||
|
self.hparams.git_sha = get_git_info()["repo_sha"]
|
||||||
|
self.num_workers = 4 if self.hparams.gpus <= 1 else None # passing num_workers breaks lightning for multigpu
|
||||||
|
|
||||||
|
def freeze_embeds(self):
|
||||||
|
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
|
||||||
|
if self.model.config.model_type == "bart":
|
||||||
|
freeze_params(self.model.model.shared)
|
||||||
|
for d in [self.model.model.encoder, self.model.model.decoder]:
|
||||||
|
freeze_params(d.embed_positions)
|
||||||
|
freeze_params(d.embed_tokens)
|
||||||
|
else:
|
||||||
|
freeze_params(self.model.shared)
|
||||||
|
for d in [self.model.encoder, self.model.decoder]:
|
||||||
|
freeze_params(d.embed_tokens)
|
||||||
|
|
||||||
|
def forward(self, input_ids, **kwargs):
|
||||||
|
return self.model(input_ids, **kwargs)
|
||||||
|
|
||||||
|
def ids_to_clean_text(self, generated_ids: List[int]):
|
||||||
|
gen_text = self.tokenizer.batch_decode(
|
||||||
|
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
|
||||||
)
|
)
|
||||||
|
return lmap(str.strip, gen_text)
|
||||||
|
|
||||||
def _step(self, batch):
|
def _step(self, batch: dict) -> Tuple:
|
||||||
pad_token_id = self.tokenizer.pad_token_id
|
pad_token_id = self.tokenizer.pad_token_id
|
||||||
source_ids, source_mask, y = batch["source_ids"], batch["source_mask"], batch["target_ids"]
|
source_ids, source_mask, y = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]
|
||||||
y_ids = y[:, :-1].contiguous()
|
y_ids = y[:, :-1].contiguous()
|
||||||
lm_labels = y[:, 1:].clone()
|
lm_labels = y[:, 1:].clone()
|
||||||
lm_labels[y[:, 1:] == pad_token_id] = -100
|
lm_labels[y[:, 1:] == pad_token_id] = -100
|
||||||
outputs = self(source_ids, attention_mask=source_mask, decoder_input_ids=y_ids, lm_labels=lm_labels,)
|
outputs = self(source_ids, attention_mask=source_mask, decoder_input_ids=y_ids, labels=lm_labels,)
|
||||||
|
|
||||||
loss = outputs[0]
|
loss = outputs[0]
|
||||||
|
return (loss,)
|
||||||
|
|
||||||
return loss
|
def training_step(self, batch, batch_idx) -> Dict:
|
||||||
|
loss_tensors = self._step(batch)
|
||||||
|
logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
|
||||||
|
return {"loss": loss_tensors[0], "log": logs}
|
||||||
|
|
||||||
-    def training_step(self, batch, batch_idx):
-        loss = self._step(batch)
-        tensorboard_logs = {"train_loss": loss}
-        return {"loss": loss, "log": tensorboard_logs}
-
-    def validation_step(self, batch, batch_idx):
-        loss = self._step(batch)
-        return {"val_loss": loss}
-
-    def validation_end(self, outputs):
-        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
-        tensorboard_logs = {"val_loss": avg_loss}
-        return {"avg_val_loss": avg_loss, "log": tensorboard_logs}
-
-    def test_step(self, batch, batch_idx):
-        pad_token_id = self.tokenizer.pad_token_id
-        source_ids, source_mask, y = SummarizationDataset.trim_seq2seq_batch(batch, pad_token_id)
-        # NOTE: the following kwargs get more speed and lower quality summaries than those in evaluate_cnn.py
-        generated_ids = self.model.generate(
-            input_ids=source_ids,
-            attention_mask=source_mask,
-            num_beams=1,
-            max_length=80,
-            repetition_penalty=2.5,
-            length_penalty=1.0,
-            early_stopping=True,
-            use_cache=True,
-        )
-        preds = [
-            self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True)
-            for g in generated_ids
-        ]
-        target = [self.tokenizer.decode(t, skip_special_tokens=True, clean_up_tokenization_spaces=True) for t in y]
-        loss = self._step(batch)
-        return {"val_loss": loss, "preds": preds, "target": target}
-
-    def test_end(self, outputs):
-        return self.validation_end(outputs)
+    def validation_step(self, batch, batch_idx) -> Dict:
+        return self._generative_step(batch)
+
+    def validation_end(self, outputs, prefix="val") -> Dict:
+        self.step_count += 1
+        losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
+        loss = losses["loss"]
+        rouges = {k: np.array([x[k] for x in outputs]).mean() for k in ROUGE_KEYS + ["gen_time", "summ_len"]}
+        rouge_tensor: torch.FloatTensor = torch.tensor(rouges["rouge2"]).type_as(loss)
+        rouges.update({k: v.item() for k, v in losses.items()})
+        losses.update(rouges)
+        metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
+        metrics["step_count"] = self.step_count
+        self.save_metrics(metrics, prefix)  # writes to self.metrics_save_path
+        preds = flatten_list([x["preds"] for x in outputs])
+        return {"log": metrics, "preds": preds, f"{prefix}_loss": loss, f"{prefix}_rouge": rouge_tensor}
+
+    def save_metrics(self, metrics, prefix) -> None:
+        self.metrics[prefix].append(metrics)
+        pickle_save(self.metrics, self.metrics_save_path)
+
+    def _generative_step(self, batch):
+        pad_token_id = self.tokenizer.pad_token_id
+        source_ids, source_mask, y = SummarizationDataset.trim_seq2seq_batch(batch, pad_token_id)
+        # TODO(SS): task specific params
+        t0 = time.time()
+        generated_ids = self.model.generate(input_ids=source_ids, attention_mask=source_mask, use_cache=True,)
+        gen_time = time.time() - t0
+        preds = self.ids_to_clean_text(generated_ids)
+        target = self.ids_to_clean_text(y)
+        loss_tensors = self._step(batch)
+        base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
+        rouge: Dict = calculate_rouge(preds, target)
+        summ_len = np.mean(lmap(len, generated_ids))
+        base_metrics.update(gen_time=gen_time, summ_len=summ_len, preds=preds, target=target, **rouge)
+        return base_metrics
+
+    def test_step(self, batch, batch_idx):
+        return self._generative_step(batch)
+
+    def test_end(self, outputs):
+        return self.validation_end(outputs, prefix="test")

     def test_epoch_end(self, outputs):
         output_test_predictions_file = os.path.join(self.hparams.output_dir, "test_predictions.txt")
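For readers tracing the new validation logic: each `_generative_step` call returns one flat dict of loss tensors, ROUGE scores, generation time and summary length, and `validation_end` simply averages those dicts across batches before logging. A minimal, self-contained sketch of that averaging, outside the actual module and with made-up numbers, looks like this:

    import torch

    # two fake per-batch outputs shaped like what _generative_step returns (illustrative values only)
    outputs = [
        {"loss": torch.tensor(2.0), "rouge2": 0.11, "gen_time": 0.5, "summ_len": 40.0},
        {"loss": torch.tensor(1.0), "rouge2": 0.13, "gen_time": 0.7, "summ_len": 44.0},
    ]
    avg_loss = torch.stack([x["loss"] for x in outputs]).mean()       # tensor(1.5000)
    avg_rouge2 = sum(x["rouge2"] for x in outputs) / len(outputs)     # 0.12
    metrics = {"val_avg_loss": avg_loss.item(), "val_avg_rouge2": avg_rouge2}
    print(metrics)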
@@ -102,15 +180,43 @@ class SummarizationTrainer(BaseTransformer):
         return self.test_end(outputs)

+    def validation_epoch_end(self, outputs):
+        self.validation_end(outputs, "val")
+
+    def get_dataset(self, type_path) -> SummarizationDataset:
+        n_obs = self.n_obs[type_path]
+        max_target_length = self.target_lens[type_path]
+        dataset = SummarizationDataset(
+            self.tokenizer,
+            type_path=type_path,
+            n_obs=n_obs,
+            max_target_length=max_target_length,
+            **self.dataset_kwargs,
+        )
+        return dataset
+
     def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
-        dataset = SummarizationDataset(self.tokenizer, type_path=type_path, **self.dataset_kwargs)
-        dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=dataset.collate_fn, shuffle=shuffle)
+        dataset = self.get_dataset(type_path)
+        sampler = None
+        if self.hparams.sortish_sampler and type_path == "train":
+            assert self.hparams.gpus <= 1  # TODO: assert earlier
+            sampler = dataset.make_sortish_sampler(batch_size)
+            shuffle = False
+
+        dataloader = DataLoader(
+            dataset,
+            batch_size=batch_size,
+            collate_fn=dataset.collate_fn,
+            shuffle=shuffle,
+            num_workers=self.num_workers,
+            sampler=sampler,
+        )
         return dataloader

     def train_dataloader(self) -> DataLoader:
         dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
         t_total = (
-            (len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.n_gpu)))
+            (len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.gpus)))
             // self.hparams.gradient_accumulation_steps
             * float(self.hparams.num_train_epochs)
         )
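The `t_total` expression above just converts the dataset size into the number of optimizer steps the LR scheduler will see. As a concrete sketch with hypothetical sizes, not values taken from the commit:

    dataset_len = 1000                              # hypothetical number of training examples
    train_batch_size, gpus = 4, 1
    gradient_accumulation_steps, num_train_epochs = 2, 3
    t_total = (
        (dataset_len // (train_batch_size * max(1, gpus)))
        // gradient_accumulation_steps
        * float(num_train_epochs)
    )
    print(t_total)  # 375.0 optimizer steps over the whole run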
@@ -129,7 +235,7 @@ class SummarizationTrainer(BaseTransformer):
     @staticmethod
     def add_model_specific_args(parser, root_dir):
         BaseTransformer.add_model_specific_args(parser, root_dir)
-        # Add BART specific options
+        add_generic_args(parser, root_dir)
         parser.add_argument(
             "--max_source_length",
             default=1024,
@@ -144,41 +250,82 @@ class SummarizationTrainer(BaseTransformer):
             help="The maximum total input sequence length after tokenization. Sequences longer "
             "than this will be truncated, sequences shorter will be padded.",
         )
+        parser.add_argument(
+            "--val_max_target_length",
+            default=142,  # these defaults are optimized for CNNDM. For xsum, see README.md.
+            type=int,
+            help="The maximum total input sequence length after tokenization. Sequences longer "
+            "than this will be truncated, sequences shorter will be padded.",
+        )
+        parser.add_argument(
+            "--test_max_target_length",
+            default=142,
+            type=int,
+            help="The maximum total input sequence length after tokenization. Sequences longer "
+            "than this will be truncated, sequences shorter will be padded.",
+        )
         parser.add_argument(
             "--data_dir",
-            default=None,
             type=str,
             required=True,
-            help="The input data dir. Should contain the dataset files for the CNN/DM summarization task.",
+            help="The input data dir. Should contain train.source, train.target, val.source, val.target, test.source, test.target",
         )
+        parser.add_argument("--freeze_encoder", action="store_true")
+        parser.add_argument("--freeze_embeds", action="store_true")
+        parser.add_argument("--sortish_sampler", action="store_true", default=False)
+        parser.add_argument("--logger", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
+        parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
+        parser.add_argument("--n_val", type=int, default=500, required=False, help="# examples. -1 means use all.")
+        parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.")
         return parser


-def main(args):
-
-    # If output_dir not provided, a folder will be generated in pwd
-    if not args.output_dir:
-        args.output_dir = os.path.join("./results", f"{args.task}_{time.strftime('%Y%m%d_%H%M%S')}",)
-        os.makedirs(args.output_dir)
-    model = SummarizationTrainer(args)
-    trainer = generic_train(model, args)
-
-    # Optionally, predict on dev set and write to output_dir
-    if args.do_predict:
-        # See https://github.com/huggingface/transformers/issues/3159
-        # pl use this format to create a checkpoint:
-        # https://github.com/PyTorchLightning/pytorch-lightning/blob/master\
-        # /pytorch_lightning/callbacks/model_checkpoint.py#L169
-        checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True)))
-        model = model.load_from_checkpoint(checkpoints[-1])
-        trainer.test(model)
+def main(args, model=None) -> SummarizationModule:
+    Path(args.output_dir).mkdir(exist_ok=True)
+    if len(os.listdir(args.output_dir)) > 3 and args.do_train:
+        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
+    if model is None:
+        model: BaseTransformer = SummarizationModule(args)
+    if (
+        args.logger == "default"
+        or args.fast_dev_run
+        or str(args.output_dir).startswith("/tmp")
+        or str(args.output_dir).startswith("/var")
+    ):
+        logger = True  # don't pollute wandb logs unnecessarily
+    elif args.logger == "wandb":
+        from pytorch_lightning.loggers import WandbLogger
+
+        logger = WandbLogger(name=model.output_dir.name)
+    elif args.logger == "wandb_shared":
+        from pytorch_lightning.loggers import WandbLogger
+
+        # TODO: separate LB for CNN, we should use Path(args.data_dir).name to determine the correct LB.
+        logger = WandbLogger(name=model.output_dir.name, project="hf_summarization")
+    trainer: pl.Trainer = generic_train(
+        model,
+        args,
+        logging_callback=Seq2SeqLoggingCallback(),
+        checkpoint_callback=get_rouge2_checkpoint_callback(args.output_dir),
+        logger=logger,
+        # TODO: early stopping callback seems messed up
+    )
+    if not args.do_predict:
+        return model
+
+    model.hparams.test_checkpoint = ""
+    checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True)))
+    if checkpoints:
+        model.hparams.test_checkpoint = checkpoints[-1]
+        trainer.resume_from_checkpoint = checkpoints[-1]
+    trainer.logger.log_hyperparams(model.hparams)
+    trainer.test(model)  # this breaks in DDP, known lightning issue. See evaluate_checkpoint to recover metrics.
+    return model


 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    add_generic_args(parser, os.getcwd())
-    parser = SummarizationTrainer.add_model_specific_args(parser, os.getcwd())
+    parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
     args = parser.parse_args()

     main(args)
examples/summarization/finetune.sh (new executable file, 23 lines)
@@ -0,0 +1,23 @@
+export OUTPUT_DIR=bart_cnn_finetune
+
+# Make output directory if it doesn't exist
+mkdir -p $OUTPUT_DIR
+
+# Add parent directory to python path to access lightning_base.py
+export PYTHONPATH="../":"${PYTHONPATH}"
+
+# --model_name_or_path=t5-base for t5
+
+python finetune.py \
+--model_name_or_path=facebook/bart-large \
+--learning_rate=3e-5 \
+--fp16 \
+--gpus 1 \
+--do_train \
+--do_predict \
+--n_val 1000 \
+--val_check_interval 0.1 \
+--sortish_sampler \
+--max_target_length=56 \
+$@
@@ -1,18 +0,0 @@ (deleted file)
-export OUTPUT_DIR_NAME=bart_sum
-export CURRENT_DIR=${PWD}
-export OUTPUT_DIR=${CURRENT_DIR}/${OUTPUT_DIR_NAME}
-
-# Make output directory if it doesn't exist
-mkdir -p $OUTPUT_DIR
-
-# Add parent directory to python path to access lightning_base.py
-export PYTHONPATH="../":"${PYTHONPATH}"
-
-python finetune.py \
---data_dir=./cnn-dailymail/cnn_dm \
---model_name_or_path=bart-large \
---learning_rate=3e-5 \
---train_batch_size=4 \
---eval_batch_size=4 \
---output_dir=$OUTPUT_DIR \
---do_train $@
examples/summarization/initialization_utils.py (new file, 20 lines)
@@ -0,0 +1,20 @@
+from typing import List
+
+from torch import nn
+
+
+def init_student(student, teacher):
+    teacher_state_dict = teacher.state_dict()
+    info = student.load_state_dict(teacher_state_dict, strict=False)
+    assert info.missing_keys == [], info.missing_keys
+    return student, info
+
+
+def copy_decoder_layers(teacher, student, l2copy=[0, 2, 4, 7, 9, 11]):
+    copy_layers(teacher.model.decoder.layers, student.model.decoder.layers, l2copy)
+
+
+def copy_layers(teacher_layers: nn.ModuleList, student_layers: nn.ModuleList, layers_to_copy: List) -> None:
+    layers_to_copy = nn.ModuleList([l for i, l in enumerate(teacher_layers) if i in layers_to_copy])
+    assert len(student_layers) == len(layers_to_copy), f"{len(student_layers)} != {len(layers_to_copy)}"
+    student_layers.load_state_dict(layers_to_copy.state_dict())
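These helpers are generic: `copy_layers` just pulls a subset of a teacher's `nn.ModuleList` into an equally sized student list. A minimal sketch of the mechanics, using plain Linear layers instead of a real BART decoder and assuming `examples/summarization` is on `sys.path`:

    from torch import nn
    from initialization_utils import copy_layers  # assumed importable from this examples directory

    teacher_layers = nn.ModuleList([nn.Linear(4, 4) for _ in range(6)])  # stand-in for teacher.model.decoder.layers
    student_layers = nn.ModuleList([nn.Linear(4, 4) for _ in range(2)])
    copy_layers(teacher_layers, student_layers, layers_to_copy=[0, 5])
    # student_layers now holds copies of the first and last teacher layers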
examples/summarization/run_distiller.sh (new executable file, 12 lines)
@@ -0,0 +1,12 @@
+#CNN_DIR = /home/shleifer/transformers_fork/examples/summarization/bart/cnn_dm
+
+# Add parent directory to python path to access lightning_base.py
+export PYTHONPATH="../":"${PYTHONPATH}"
+
+python distillation.py \
+--learning_rate=3e-4 \
+--do_train \
+--do_predict \
+--fp16 \
+--val_check_interval 0.1 \
+$@
examples/summarization/run_eval.py (new file, 78 lines)
@@ -0,0 +1,78 @@
+import argparse
+import json
+from pathlib import Path
+
+import torch
+from tqdm import tqdm
+
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+
+
+try:
+    from .finetune import calculate_rouge, use_task_specific_params
+except ImportError:
+    from finetune import calculate_rouge, use_task_specific_params
+
+DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+def chunks(lst, n):
+    """Yield successive n-sized chunks from lst."""
+    for i in range(0, len(lst), n):
+        yield lst[i : i + n]
+
+
+def generate_summaries(
+    examples: list, out_file: str, model_name: str, batch_size: int = 8, device: str = DEFAULT_DEVICE, fp16=False,
+) -> None:
+    fout = Path(out_file).open("w", encoding="utf-8")
+    model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
+    if fp16:
+        model = model.half()
+
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+    # update config with summarization specific params
+    use_task_specific_params(model, "summarization")
+
+    for batch in tqdm(list(chunks(examples, batch_size))):
+        if "t5" in model_name:
+            batch = [model.config.prefix + text for text in batch]
+        dct = tokenizer.batch_encode_plus(batch, max_length=1024, return_tensors="pt", pad_to_max_length=True).to(
+            device
+        )
+        summaries = model.generate(**dct)
+        dec = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)
+        for hypothesis in dec:
+            fout.write(hypothesis + "\n")
+            fout.flush()
+
+
+def run_generate():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("input_path", type=str, help="like cnn_dm/test.source")
+    parser.add_argument("output_path", type=str, help="where to save summaries")
+    parser.add_argument("model_name", type=str, help="like facebook/bart-large-cnn,t5-base, etc.")
+    parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test_reference_summaries.txt")
+    parser.add_argument("--score_path", type=str, required=False, help="where to save the rouge score in json format")
+    parser.add_argument("--device", type=str, required=False, default=DEFAULT_DEVICE, help="cuda, cuda:1, cpu etc.")
+    parser.add_argument("--bs", type=int, default=8, required=False, help="batch size")
+    parser.add_argument("--fp16", action="store_true")
+    args = parser.parse_args()
+    examples = [" " + x.rstrip() if "t5" in args.model_name else x.rstrip() for x in open(args.input_path).readlines()]
+
+    generate_summaries(
+        examples, args.output_path, args.model_name, batch_size=args.bs, device=args.device, fp16=args.fp16
+    )
+    if args.score_path is not None:
+        output_lns = [x.rstrip() for x in open(args.output_path).readlines()]
+        reference_lns = [x.rstrip() for x in open(args.reference_path).readlines()]
+
+        rouge: dict = calculate_rouge(output_lns, reference_lns)
+
+        json.dump(rouge, open("score_path", "w+"))
+
+
+if __name__ == "__main__":
+    run_generate()
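`run_eval.py` is primarily a CLI, but `generate_summaries` can also be driven directly from Python. A hedged sketch, assuming the file is importable from the current directory and using the tiny test checkpoint referenced elsewhere in this commit; the output path is a placeholder:

    from run_eval import generate_summaries

    articles = ["New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]
    generate_summaries(articles, "out.hypo", "sshleifer/bart-tiny-random", batch_size=1, device="cpu")
    print(open("out.hypo").read())  # one hypothesis line per input article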
@@ -7,28 +7,40 @@ import unittest
 from pathlib import Path
 from unittest.mock import patch

+import torch
 from torch.utils.data import DataLoader

 from transformers import BartTokenizer

-from .evaluate_cnn import run_generate
+from .distillation import distill_main, evaluate_checkpoint
 from .finetune import main
-from .utils import SummarizationDataset
+from .run_eval import generate_summaries, run_generate
+from .utils import SummarizationDataset, lmap, pickle_load


 logging.basicConfig(level=logging.DEBUG)

 logger = logging.getLogger()
-DEFAULT_ARGS = {
+FP16_EVER = False
+CHEAP_ARGS = {
+    "logger": "default",
+    "alpha_hid": 0,
+    "freeze_embeds": True,
+    "enc_only": False,
+    "tgt_suffix": "",
+    "resume_from_checkpoint": None,
+    "sortish_sampler": True,
+    "student_decoder_layers": 1,
+    "val_check_interval": 1.0,
     "output_dir": "",
     "fp16": False,
+    "no_teacher": False,
     "fp16_opt_level": "O1",
-    "n_gpu": 1,
+    "gpus": 1 if torch.cuda.is_available() else 0,
     "n_tpu_cores": 0,
     "max_grad_norm": 1.0,
     "do_train": True,
-    "do_predict": False,
+    "do_predict": True,
     "gradient_accumulation_steps": 1,
     "server_ip": "",
     "server_port": "",
@@ -36,7 +48,7 @@ DEFAULT_ARGS = {
     "model_type": "bart",
     "model_name_or_path": "sshleifer/bart-tiny-random",
     "config_name": "",
-    "tokenizer_name": "",
+    "tokenizer_name": "facebook/bart-large",
     "cache_dir": "",
     "do_lower_case": False,
     "learning_rate": 3e-05,
@@ -48,6 +60,17 @@ DEFAULT_ARGS = {
     "eval_batch_size": 2,
     "max_source_length": 12,
     "max_target_length": 12,
+    "val_max_target_length": 12,
+    "test_max_target_length": 12,
+    "fast_dev_run": False,
+    "no_cache": False,
+    "n_train": -1,
+    "n_val": -1,
+    "n_test": -1,
+    "student_encoder_layers": 1,
+    "alpha_loss_encoder": 0.0,
+    "freeze_encoder": False,
+    "auto_scale_batch_size": False,
 }
@@ -56,6 +79,9 @@ def _dump_articles(path: Path, articles: list):
         f.write("\n".join(articles))


+BDIR = Path("~/transformers_fork/examples/summarization/bart/").absolute()
+
+
 def make_test_data_dir():
     tmp_dir = Path(tempfile.gettempdir())
     articles = [" Sam ate lunch today", "Sams lunch ingredients"]
@@ -66,6 +92,169 @@ def make_test_data_dir():
     return tmp_dir


+@unittest.skip("These wont' pass until hidden_states kwarg is merged.")
+class TestSummarizationDistiller(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        logging.disable(logging.CRITICAL)  # remove noisy download output from tracebacks
+        return cls
+
+    @unittest.skipUnless(torch.cuda.device_count() > 1, "skipping multiGPU test")
+    def test_bdc_multigpu(self):
+        updates = dict(
+            student_encoder_layers=2,
+            student_decoder_layers=1,
+            no_teacher=True,
+            freeze_encoder=True,
+            gpus=2,
+            sortish_sampler=False,
+        )
+        self._bart_distiller_cli(updates)
+
+    @unittest.skipUnless(torch.cuda.is_available(), "skipping fp16 test")
+    def test_bdc_fp16(self):
+        updates = dict(
+            student_encoder_layers=2,
+            student_decoder_layers=1,
+            alpha_hid=3.0,
+            freeze_encoder=True,
+            gpus=1,
+            fp16=FP16_EVER,
+            fp16_opt_level="O1",
+        )
+        self._bart_distiller_cli(updates)
+
+    @unittest.skipUnless(torch.cuda.is_available(), "skipping fp16 test")
+    def test_bdc_t5_eval_fp16(self):
+        updates = dict(
+            fp16=FP16_EVER,
+            gpus=1,
+            model_type="t5",
+            model_name_or_path="patrickvonplaten/t5-tiny-random",
+            do_train=False,
+            do_predict=True,
+            tokenizer_name=None,
+            no_teacher=True,
+        )
+        self._bart_distiller_cli(updates, check_contents=False)
+
+    @unittest.skipUnless(torch.cuda.is_available(), "skipping fp16 test")
+    def test_bdc_t5_train_fp16(self):
+        updates = dict(
+            fp16=FP16_EVER,
+            gpus=1,
+            model_type="t5",
+            model_name_or_path="patrickvonplaten/t5-tiny-random",
+            do_train=True,
+            do_predict=True,
+            tokenizer_name="patrickvonplaten/t5-tiny-random",
+            no_teacher=True,
+        )
+        self._bart_distiller_cli(updates)
+
+    def test_bdc_no_teacher(self):
+        updates = dict(student_encoder_layers=2, student_decoder_layers=1, no_teacher=True,)
+        self._bart_distiller_cli(updates)
+
+    def test_bdc_yes_teacher(self):
+        updates = dict(student_encoder_layers=2, student_decoder_layers=1,)
+        self._bart_distiller_cli(updates)
+
+    def test_bdc_checkpointing(self):
+        updates = dict(
+            student_encoder_layers=2,
+            student_decoder_layers=1,
+            num_train_epochs=4,
+            val_check_interval=0.25,
+            alpha_hid=2.0,
+        )
+        model = self._bart_distiller_cli(updates, check_contents=False)
+
+        ckpts = list(Path(model.output_dir).glob("*.ckpt"))
+        self.assertEqual(1, len(ckpts))
+        transformer_ckpts = list(Path(model.output_dir).glob("**/*.bin"))
+        self.assertEqual(len(transformer_ckpts), len(ckpts))
+        new_transformer_ckpts = list(Path(model.output_dir).glob("**/*.bin"))
+        self.assertEqual(len(new_transformer_ckpts), 1)
+        examples = lmap(str.strip, model.hparams.data_dir.joinpath("test.source").open().readlines())
+        out_path = tempfile.mktemp()
+        generate_summaries(examples, out_path, new_transformer_ckpts[0].parent)
+        self.assertTrue(Path(out_path).exists())
+
+        evaluate_checkpoint(ckpts[0], dest_dir=Path(tempfile.mkdtemp()))
+
+    def test_bdc_t5(self):
+        updates = dict(
+            student_encoder_layers=1,
+            student_decoder_layers=1,
+            alpha_hid=2.0,
+            teacher="patrickvonplaten/t5-tiny-random",
+            model_type="t5",
+            model_name_or_path="patrickvonplaten/t5-tiny-random",
+            tokenizer_name="patrickvonplaten/t5-tiny-random",
+        )
+        self._bart_distiller_cli(updates)
+
+    def test_bdc_t5_eval(self):
+        updates = dict(
+            model_type="t5",
+            model_name_or_path="patrickvonplaten/t5-tiny-random",
+            do_train=False,
+            do_predict=True,
+            tokenizer_name="patrickvonplaten/t5-tiny-random",
+            no_teacher=True,
+        )
+        self._bart_distiller_cli(updates, check_contents=False)
+
+    def _bart_distiller_cli(self, updates, check_contents=True):
+        default_updates = dict(
+            model_type="bart",
+            train_batch_size=1,
+            eval_batch_size=2,
+            num_train_epochs=2,
+            alpha_mlm=0.2,
+            alpha_ce=0.8,
+            do_predict=True,
+            gpus=1 if torch.cuda.is_available() else 0,
+            model_name_or_path="sshleifer/tinier_bart",
+            teacher=CHEAP_ARGS["model_name_or_path"],
+            val_check_interval=0.5,
+            alpha_encoder_loss=0.4,
+        )
+        default_updates.update(updates)
+        args_d: dict = CHEAP_ARGS.copy()
+        tmp_dir = make_test_data_dir()
+        output_dir = tempfile.mkdtemp(prefix="output_")
+
+        args_d.update(data_dir=tmp_dir, output_dir=output_dir, **default_updates)
+        model = distill_main(argparse.Namespace(**args_d))
+        if not check_contents:
+            return model
+        contents = os.listdir(output_dir)
+        ckpt_name = "val_avg_rouge2=0.0000-step_count=2.ckpt"  # "val_avg_rouge2=0.0000-epoch=1.ckpt"  # "epoch=1-val_avg_rouge2=0.0000.ckpt"
+        contents = {os.path.basename(p) for p in contents}
+        self.assertIn(ckpt_name, contents)
+        self.assertIn("metrics.pkl", contents)
+        self.assertIn("test_generations.txt", contents)
+        self.assertIn("val_generations_1.txt", contents)
+        self.assertIn("val_1_results.txt", contents)
+        self.assertIn("test_results.txt", contents)
+        # self.assertEqual(len(contents), 15)
+
+        metrics = pickle_load(Path(output_dir) / "metrics.pkl")
+        import pandas as pd
+
+        val_df = pd.DataFrame(metrics["val"])
+        train_df = pd.DataFrame(metrics["train"])
+        test_df = pd.DataFrame(metrics["test"])
+        desired_n_evals = args_d["num_train_epochs"] * 2 + 1
+        self.assertEqual(val_df.shape[0], desired_n_evals)
+        self.assertEqual(test_df.shape[1], val_df.shape[1])
+        self.assertEqual(train_df.shape[0], 0)
+        return model
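The `desired_n_evals` check above encodes a simple expectation about how often validation metrics get recorded; reading the `+ 1` as the initial sanity-check validation is an inference from the test, not something stated in the commit:

    num_train_epochs = 2            # the _bart_distiller_cli default
    evals_per_epoch = 2             # val_check_interval=0.5 means two validation passes per epoch
    desired_n_evals = num_train_epochs * evals_per_epoch + 1   # +1 presumed to be the pre-training sanity check
    print(desired_n_evals)  # 5 rows expected in metrics["val"]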
 class TestBartExamples(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
@@ -79,49 +268,31 @@ class TestBartExamples(unittest.TestCase):
         output_file_name = Path(tempfile.gettempdir()) / "utest_output_bart_sum.hypo"
         articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]
         _dump_articles(tmp, articles)
-        testargs = ["evaluate_cnn.py", str(tmp), str(output_file_name), "sshleifer/bart-tiny-random"]
+        testargs = ["run_eval.py", str(tmp), str(output_file_name), "sshleifer/bart-tiny-random"]
         with patch.object(sys, "argv", testargs):
             run_generate()
             self.assertTrue(Path(output_file_name).exists())
             os.remove(Path(output_file_name))

-    def test_bart_run_sum_cli(self):
-        args_d: dict = DEFAULT_ARGS.copy()
-        tmp_dir = make_test_data_dir()
-        output_dir = tempfile.mkdtemp(prefix="output_")
-        args_d.update(
-            data_dir=tmp_dir, model_type="bart", train_batch_size=2, eval_batch_size=2, n_gpu=0, output_dir=output_dir,
-        )
-        main(argparse.Namespace(**args_d))
-        args_d.update({"do_train": False, "do_predict": True})
-
-        main(argparse.Namespace(**args_d))
-        contents = os.listdir(output_dir)
-        expected_contents = {
-            "checkpointepoch=0.ckpt",
-            "test_results.txt",
-        }
-        created_files = {os.path.basename(p) for p in contents}
-        self.assertSetEqual(expected_contents, created_files)
-
     def test_t5_run_sum_cli(self):
-        args_d: dict = DEFAULT_ARGS.copy()
+        args_d: dict = CHEAP_ARGS.copy()
         tmp_dir = make_test_data_dir()
         output_dir = tempfile.mkdtemp(prefix="output_")
         args_d.update(
             data_dir=tmp_dir,
             model_type="t5",
             model_name_or_path="patrickvonplaten/t5-tiny-random",
+            tokenizer_name=None,  # "patrickvonplaten/t5-tiny-random",
             train_batch_size=2,
             eval_batch_size=2,
-            n_gpu=0,
+            gpus=0,
             output_dir=output_dir,
             do_predict=True,
         )
-        main(argparse.Namespace(**args_d))
-        # args_d.update({"do_train": False, "do_predict": True})
-        # main(argparse.Namespace(**args_d))
+        assert "n_train" in args_d
+        args = argparse.Namespace(**args_d)
+        main(args)

     def test_bart_summarization_dataset(self):
         tmp_dir = Path(tempfile.gettempdir())
@@ -138,42 +309,16 @@ class TestBartExamples(unittest.TestCase):
         )
         dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
         for batch in dataloader:
-            self.assertEqual(batch["source_mask"].shape, batch["source_ids"].shape)
+            self.assertEqual(batch["attention_mask"].shape, batch["input_ids"].shape)
             # show that articles were trimmed.
-            self.assertEqual(batch["source_ids"].shape[1], max_len_source)
-            self.assertGreater(20, batch["source_ids"].shape[1])  # trimmed significantly
+            self.assertEqual(batch["input_ids"].shape[1], max_len_source)
+            self.assertGreater(20, batch["input_ids"].shape[1])  # trimmed significantly
             # show that targets were truncated
-            self.assertEqual(batch["target_ids"].shape[1], trunc_target)  # Truncated
+            self.assertEqual(batch["decoder_input_ids"].shape[1], trunc_target)  # Truncated
             self.assertGreater(max_len_target, trunc_target)  # Truncated


-class TestT5Examples(unittest.TestCase):
-    def test_t5_cli(self):
-        output_file_name = "output_t5_sum.txt"
-        score_file_name = "score_t5_sum.txt"
-        articles = ["New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]
-        stream_handler = logging.StreamHandler(sys.stdout)
-        logger.addHandler(stream_handler)
-        tmp = Path(tempfile.gettempdir()) / "utest_generations_t5_sum.hypo"
-        with tmp.open("w", encoding="utf-8") as f:
-            f.write("\n".join(articles))
-
-        output_file_name = Path(tempfile.gettempdir()) / "utest_output_t5_sum.hypo"
-        score_file_name = Path(tempfile.gettempdir()) / "utest_score_t5_sum.hypo"
-
-        testargs = [
-            "evaluate_cnn.py",
-            str(tmp),
-            str(output_file_name),
-            "patrickvonplaten/t5-tiny-random",
-            "--reference_path",
-            str(tmp),
-            "--score_path",
-            str(score_file_name),
-        ]
-
-        with patch.object(sys, "argv", testargs):
-            run_generate()
-        self.assertTrue(Path(output_file_name).exists())
-        self.assertTrue(Path(score_file_name).exists())
+def list_to_text_file(lst, path):
+    dest = Path(path)
+    dest.open("w+").writelines(lst)
@@ -1,20 +1,66 @@
+import itertools
+import json
 import os
+import pickle
+from pathlib import Path
+from typing import Dict, Iterable, List

+import git
+import numpy as np
 import torch
-from torch.utils.data import Dataset
+from rouge_score import rouge_scorer, scoring
+from torch import nn
+from torch.utils.data import Dataset, Sampler
+from tqdm import tqdm

+from transformers import BartTokenizer


-def encode_file(tokenizer, data_path, max_length, pad_to_max_length=True, return_tensors="pt"):
+def encode_file(
+    tokenizer,
+    data_path,
+    max_length,
+    pad_to_max_length=True,
+    return_tensors="pt",
+    overwrite_cache=False,
+    prefix="",
+    tok_name="",
+):
+    cache_path = Path(f"{data_path}_{tok_name}{max_length}.pt")
+    if not overwrite_cache and cache_path.exists():
+        try:
+            examples = torch.load(cache_path)
+            assert isinstance(examples, list)
+            return examples
+
+        except Exception:
+            print(f"failed to load from {cache_path}, retokenizing {data_path}")
+    data_path = Path(data_path)
+
+    lns = lmap(str.strip, data_path.open().readlines())
+    lns = [prefix + text for text in lns]
+    assert lns, f"found empty file at {data_path}"
     examples = []
-    with open(data_path, "r") as f:
-        for text in f.readlines():
-            tokenized = tokenizer.batch_encode_plus(
-                [text], max_length=max_length, pad_to_max_length=pad_to_max_length, return_tensors=return_tensors,
-            )
-            examples.append(tokenized)
+    for text in tqdm(lns, desc=f"Tokenizing {data_path.name}"):
+        tokenized = tokenizer.batch_encode_plus(
+            [text],  # DONT ADD SPACES
+            max_length=max_length,
+            pad_to_max_length=pad_to_max_length,
+            add_prefix_space=True,
+            return_tensors=return_tensors,
+        )
+        examples.append(tokenized)
+    torch.save(lmap(dict, examples), cache_path.open("wb"))
     return examples


+def lmap(f, x):
+    return list(map(f, x))
+
+
+T5_PREFIX = "summarize: "  # HACK, fixme
+
+
 def trim_batch(
     input_ids, pad_token_id, attention_mask=None,
 ):
@@ -30,15 +76,38 @@ class SummarizationDataset(Dataset):
     def __init__(
         self,
         tokenizer,
-        data_dir="./cnn-dailymail/cnn_dm/",
+        data_dir,
         type_path="train",
         max_source_length=1024,
         max_target_length=56,
+        n_obs=None,
+        overwrite_cache=False,
+        prefix="",
     ):
         super().__init__()
-        self.tokenizer = tokenizer
-        self.source = encode_file(tokenizer, os.path.join(data_dir, type_path + ".source"), max_source_length)
-        self.target = encode_file(tokenizer, os.path.join(data_dir, type_path + ".target"), max_target_length)
+        tok_name = "T5" if not isinstance(tokenizer, BartTokenizer) else ""
+        self.source = encode_file(
+            tokenizer,
+            os.path.join(data_dir, type_path + ".source"),
+            max_source_length,
+            overwrite_cache=overwrite_cache,
+            prefix=prefix,
+            tok_name=tok_name,
+        )
+        if type_path == "train":
+            tgt_path = os.path.join(data_dir, type_path + ".target")
+        else:
+            tgt_path = os.path.join(data_dir, type_path + ".target")
+
+        self.target = encode_file(
+            tokenizer, tgt_path, max_target_length, overwrite_cache=overwrite_cache, tok_name=tok_name
+        )
+        if n_obs is not None:
+            self.source = self.source[:n_obs]
+            self.target = self.target[:n_obs]
+        self.pad_token_id = tokenizer.pad_token_id

     def __len__(self):
         return len(self.source)
@@ -47,19 +116,141 @@ class SummarizationDataset(Dataset):
         source_ids = self.source[index]["input_ids"].squeeze()
         target_ids = self.target[index]["input_ids"].squeeze()
         src_mask = self.source[index]["attention_mask"].squeeze()
-        return {"source_ids": source_ids, "source_mask": src_mask, "target_ids": target_ids}
+        return {"input_ids": source_ids, "attention_mask": src_mask, "decoder_input_ids": target_ids}

     @staticmethod
     def trim_seq2seq_batch(batch, pad_token_id):
-        y = trim_batch(batch["target_ids"], pad_token_id)
-        source_ids, source_mask = trim_batch(batch["source_ids"], pad_token_id, attention_mask=batch["source_mask"])
+        y = trim_batch(batch["decoder_input_ids"], pad_token_id)
+        source_ids, source_mask = trim_batch(batch["input_ids"], pad_token_id, attention_mask=batch["attention_mask"])
         return source_ids, source_mask, y

-    def collate_fn(self, batch):
-        input_ids = torch.stack([x["source_ids"] for x in batch])
-        masks = torch.stack([x["source_mask"] for x in batch])
-        target_ids = torch.stack([x["target_ids"] for x in batch])
-        pad_token_id = self.tokenizer.pad_token_id
+    def collate_fn(self, batch) -> dict:
+        input_ids = torch.stack([x["input_ids"] for x in batch])
+        masks = torch.stack([x["attention_mask"] for x in batch])
+        target_ids = torch.stack([x["decoder_input_ids"] for x in batch])
+        pad_token_id = self.pad_token_id
         y = trim_batch(target_ids, pad_token_id)
         source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks)
-        return {"source_ids": source_ids, "source_mask": source_mask, "target_ids": y}
+        batch = {"input_ids": source_ids, "attention_mask": source_mask, "decoder_input_ids": y}
+        return batch
+
+    @property
+    def src_lens(self):  # Can delete?
+        return lmap(len, self.source)
+
+    @property
+    def tgt_lens(self):
+        return lmap(len, self.target)
+
+    def make_sortish_sampler(self, batch_size):
+        return SortishSampler(self.source, batch_size)
+
+
+class SortishSampler(Sampler):
+    "Go through the text data by order of src length with a bit of randomness. From fastai repo."
+
+    def __init__(self, data, batch_size):
+        self.data, self.bs = data, batch_size
+
+    def key(self, i):
+        return len(self.data[i])
+
+    def __len__(self) -> int:
+        return len(self.data)
+
+    def __iter__(self):
+        idxs = np.random.permutation(len(self.data))
+        sz = self.bs * 50
+        ck_idx = [idxs[i : i + sz] for i in range(0, len(idxs), sz)]
+        sort_idx = np.concatenate([sorted(s, key=self.key, reverse=True) for s in ck_idx])
+        sz = self.bs
+        ck_idx = [sort_idx[i : i + sz] for i in range(0, len(sort_idx), sz)]
+        max_ck = np.argmax([self.key(ck[0]) for ck in ck_idx])  # find the chunk with the largest key,
+        ck_idx[0], ck_idx[max_ck] = ck_idx[max_ck], ck_idx[0]  # then make sure it goes first.
+        sort_idx = np.concatenate(np.random.permutation(ck_idx[1:])) if len(ck_idx) > 1 else np.array([], dtype=np.int)
+        sort_idx = np.concatenate((ck_idx[0], sort_idx))
+        return iter(sort_idx)
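`SortishSampler` shuffles indices in large chunks and then sorts each chunk by source length, so consecutive batches contain examples of similar length and need little padding. A minimal sketch of using it standalone, assuming this `utils.py` is importable and using toy "tokenized" examples rather than a SummarizationDataset:

    import numpy as np
    from utils import SortishSampler  # assumed importable from examples/summarization

    data = [list(range(n)) for n in np.random.randint(1, 50, size=256)]  # toy examples with varying lengths
    sampler = SortishSampler(data, batch_size=8)     # what dataset.make_sortish_sampler(8) returns
    order = list(iter(sampler))                      # indices, shuffled in chunks then sorted by length
    # feeding this order to a DataLoader (shuffle=False, sampler=sampler) groups similar lengths together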
+def use_task_specific_params(model, task):
+    # update config with summarization specific params
+    task_specific_params = model.config.task_specific_params
+    if task_specific_params is not None:
+        model.config.update(task_specific_params.get(task, {}))
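`use_task_specific_params` relies on the `task_specific_params` field that some pretrained configs ship. A hedged illustration of the shape it expects; the nested keys and values below are assumptions for the sketch, not taken from this commit, and real checkpoints supply their own:

    from transformers import AutoModelForSeq2SeqLM
    from utils import use_task_specific_params  # defined above, assumed importable

    model = AutoModelForSeq2SeqLM.from_pretrained("sshleifer/bart-tiny-random")
    # hypothetical task_specific_params entry for demonstration only
    model.config.task_specific_params = {"summarization": {"max_length": 142, "min_length": 56, "num_beams": 4}}
    use_task_specific_params(model, "summarization")  # copies the nested dict into the top-level config
    print(model.config.max_length, model.config.num_beams)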
+def pickle_load(path):
+    """pickle.load(path)"""
+    with open(path, "rb") as f:
+        return pickle.load(f)
+
+
+def pickle_save(obj, path):
+    """pickle.dump(obj, path)"""
+    with open(path, "wb") as f:
+        return pickle.dump(obj, f)
+
+
+def flatten_list(summary_ids: List[List]):
+    return [x for x in itertools.chain.from_iterable(summary_ids)]
+
+
+def save_git_info(folder_path: str):
+    """
+    Log commit info.
+    """
+    repo_infos = get_git_info()
+
+    with open(os.path.join(folder_path, "git_log.json"), "w") as f:
+        json.dump(repo_infos, f, indent=4)
+
+
+def get_git_info():
+    repo = git.Repo(search_parent_directories=True)
+    repo_infos = {
+        "repo_id": str(repo),
+        "repo_sha": str(repo.head.object.hexsha),
+        "repo_branch": str(repo.active_branch),
+    }
+    return repo_infos
+
+
+ROUGE_KEYS = ["rouge1", "rouge2", "rougeL"]
+
+
+def calculate_rouge(output_lns: List[str], reference_lns: List[str]) -> Dict:
+    scorer = rouge_scorer.RougeScorer(ROUGE_KEYS, use_stemmer=True)
+    aggregator = scoring.BootstrapAggregator()
+
+    for reference_ln, output_ln in zip(reference_lns, output_lns):
+        scores = scorer.score(reference_ln, output_ln)
+        aggregator.add_scores(scores)
+
+    result = aggregator.aggregate()
+    return {k: v.mid.fmeasure for k, v in result.items()}
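A quick sanity check of `calculate_rouge`, assuming the `rouge_score` package is installed and `utils.py` is importable; the strings below are made up:

    from utils import calculate_rouge

    preds = ["the cat sat on the mat"]
    refs = ["a cat was sitting on the mat"]
    scores = calculate_rouge(preds, refs)
    print(scores)  # {'rouge1': ..., 'rouge2': ..., 'rougeL': ...} mid f-measures between 0 and 1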
+def freeze_params(model: nn.Module):
+    for par in model.parameters():
+        par.requires_grad = False
+
+
+def grad_status(model: nn.Module) -> Iterable:
+    return (par.requires_grad for par in model.parameters())
+
+
+def any_requires_grad(model: nn.Module) -> bool:
+    return any(grad_status(model))
+
+
+def assert_all_frozen(model):
+    model_grads: List[bool] = list(grad_status(model))
+    n_require_grad = sum(lmap(int, model_grads))
+    npars = len(model_grads)
+    assert not any(model_grads), f"{n_require_grad/npars:.1%} of {npars} weights require grad"
+
+
+def assert_not_all_frozen(model):
+    model_grads: List[bool] = list(grad_status(model))
+    npars = len(model_grads)
+    assert any(model_grads), f"none of {npars} weights require grad"
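These helpers back the new --freeze_encoder and --freeze_embeds flags. A small sketch of how they compose, using a plain Linear stack rather than a real encoder:

    from torch import nn
    from utils import freeze_params, assert_all_frozen, assert_not_all_frozen  # assumed importable

    model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))  # stand-in for model.model.encoder
    freeze_params(model[0])        # freeze only the first block
    assert_not_all_frozen(model)   # some weights still train, so this passes
    freeze_params(model)
    assert_all_frozen(model)       # now nothing requires grad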
@@ -59,7 +59,7 @@ BART_GENERATION_EXAMPLE = r"""
     Examples::

         from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig
-        # see ``examples/summarization/bart/evaluate_cnn.py`` for a longer example
+        # see ``examples/summarization/bart/run_eval.py`` for a longer example
         model = BartForConditionalGeneration.from_pretrained('bart-large-cnn')
         tokenizer = BartTokenizer.from_pretrained('bart-large-cnn')
         ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."