# transformers/examples/seq2seq/finetune.py (410 lines, 16 KiB, Python)

import argparse
import glob
import logging
import os
import time
import warnings
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple
import numpy as np
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader
from lightning_base import BaseTransformer, add_generic_args, generic_train
from transformers import MBartTokenizer, T5ForConditionalGeneration, get_linear_schedule_with_warmup
try:
from .utils import (
assert_all_frozen,
use_task_specific_params,
lmap,
flatten_list,
pickle_save,
save_git_info,
save_json,
freeze_params,
calculate_rouge,
get_git_info,
ROUGE_KEYS,
calculate_bleu_score,
Seq2SeqDataset,
MBartDataset,
label_smoothed_nll_loss,
)
from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
except ImportError:
from utils import (
Seq2SeqDataset,
MBartDataset,
assert_all_frozen,
use_task_specific_params,
lmap,
flatten_list,
pickle_save,
save_git_info,
save_json,
freeze_params,
calculate_rouge,
get_git_info,
ROUGE_KEYS,
calculate_bleu_score,
label_smoothed_nll_loss,
)
from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class SummarizationModule(BaseTransformer):
    """Seq2seq fine-tuning module that scores generated text with ROUGE.

    Training minimizes a token-level loss (cross entropy, or the label-smoothed
    NLL when ``hparams.label_smoothing`` is non-zero). Validation and test steps
    additionally call ``model.generate`` and score the decoded text through
    ``calc_generative_metrics``.
    """

    mode = "summarization"  # forwarded to BaseTransformer.__init__ as ``mode``
    loss_names = ["loss"]  # zipped with the tuple returned by ``_step``
    metric_names = ROUGE_KEYS  # per-batch keys averaged in ``validation_epoch_end``
    val_metric = "rouge2"  # reported as ``{prefix}_{val_metric}`` at epoch end

    def __init__(self, hparams, **kwargs):
        """Initialise the base transformer and the metric/dataset bookkeeping.

        Side effects: calls ``save_git_info`` on the output directory and pickles
        ``hparams`` to ``hparams.pkl`` inside it.

        Args:
            hparams: Parsed arguments; ``add_model_specific_args`` lists the
                options this class adds.
            **kwargs: Forwarded unchanged to ``BaseTransformer.__init__``.

        Raises:
            AssertionError: If the training target length is larger than the
                validation or test target length.
        """
        super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs)
        # NOTE(review): the task name is hard-coded, so subclasses with a different
        # ``mode`` still request the "summarization" params — confirm this is intended.
        use_task_specific_params(self.model, "summarization")
        save_git_info(self.hparams.output_dir)
        self.metrics_save_path = Path(self.output_dir) / "metrics.json"
        self.hparams_save_path = Path(self.output_dir) / "hparams.pkl"
        pickle_save(self.hparams, self.hparams_save_path)
        self.step_count = 0
        self.metrics = defaultdict(list)

        self.dataset_kwargs: dict = dict(
            data_dir=self.hparams.data_dir,
            max_source_length=self.hparams.max_source_length,
            prefix=self.model.config.prefix or "",
        )
        n_observations_per_split = {
            "train": self.hparams.n_train,
            "val": self.hparams.n_val,
            "test": self.hparams.n_test,
        }
        # Negative counts (the CLI uses -1 for "use all") are stored as None.
        self.n_obs = {k: v if v >= 0 else None for k, v in n_observations_per_split.items()}

        self.target_lens = {
            "train": self.hparams.max_target_length,
            "val": self.hparams.val_max_target_length,
            "test": self.hparams.test_max_target_length,
        }
        assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
        assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"

        if self.hparams.freeze_embeds:
            self.freeze_embeds()
        if self.hparams.freeze_encoder:
            freeze_params(self.model.get_encoder())
            assert_all_frozen(self.model.get_encoder())

        self.hparams.git_sha = get_git_info()["repo_sha"]
        self.num_workers = hparams.num_workers
        # Subclasses may override these two (TranslationModule does for MBart).
        self.decoder_start_token_id = None
        self.dataset_class = Seq2SeqDataset

    def freeze_embeds(self):
        """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
        try:
            # Layout with a nested ``.model`` attribute (bart).
            freeze_params(self.model.model.shared)
            for d in [self.model.model.encoder, self.model.model.decoder]:
                freeze_params(d.embed_positions)
                freeze_params(d.embed_tokens)
        except AttributeError:
            # Layout where shared/encoder/decoder hang directly off the model (t5).
            freeze_params(self.model.shared)
            for d in [self.model.encoder, self.model.decoder]:
                freeze_params(d.embed_tokens)

    def forward(self, input_ids, **kwargs):
        """Call the wrapped model with ``input_ids`` and any extra keyword arguments."""
        return self.model(input_ids, **kwargs)

    def ids_to_clean_text(self, generated_ids):
        """Decode a batch of token ids to stripped strings.

        Args:
            generated_ids: Batch of token-id sequences, handed as-is to
                ``tokenizer.batch_decode``.

        Returns:
            The decoded texts with special tokens removed and ``str.strip`` applied.
        """
        gen_text = self.tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        return lmap(str.strip, gen_text)

    def _step(self, batch: dict) -> Tuple:
        """Run a teacher-forced forward pass and compute the training loss.

        Args:
            batch: Mapping with ``input_ids``, ``attention_mask`` and
                ``decoder_input_ids`` (the target token ids).

        Returns:
            A 1-tuple ``(loss,)`` aligned with ``loss_names``.
        """
        pad_token_id = self.tokenizer.pad_token_id
        source_ids, source_mask, target_ids = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]

        if isinstance(self.model, T5ForConditionalGeneration):
            decoder_input_ids = self.model._shift_right(target_ids)
            lm_labels = target_ids
        else:
            # Teacher forcing: the decoder is fed every token but the last and is
            # trained to predict every token but the first.
            decoder_input_ids = target_ids[:, :-1].contiguous()
            # clone() gives the labels their own storage instead of a view into the
            # batch tensor; presumably this also makes ``.view(-1)`` below safe on
            # the sliced tensor — TODO confirm.
            lm_labels = target_ids[:, 1:].clone()

        outputs = self(source_ids, attention_mask=source_mask, decoder_input_ids=decoder_input_ids, use_cache=False)

        if self.hparams.label_smoothing == 0:
            # Same behavior as modeling_bart.py
            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
            lm_logits = outputs[0]
            assert (
                lm_logits.shape[-1] == self.model.config.vocab_size
            ), f"logits last dim {lm_logits.shape[-1]} != vocab_size {self.model.config.vocab_size}"
            loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), lm_labels.view(-1))
        else:
            lprobs = torch.nn.functional.log_softmax(outputs[0], dim=-1)
            # Only the smoothed loss is used; the second return value is discarded.
            loss, _nll_loss = label_smoothed_nll_loss(
                lprobs, lm_labels, self.hparams.label_smoothing, ignore_index=pad_token_id
            )
        return (loss,)

    def training_step(self, batch, batch_idx) -> Dict:
        """Compute the loss for one training batch and expose it under ``log``."""
        loss_tensors = self._step(batch)
        logs = dict(zip(self.loss_names, loss_tensors))
        return {"loss": loss_tensors[0], "log": logs}

    def validation_step(self, batch, batch_idx) -> Dict:
        """Generate for one validation batch and return its losses and metrics."""
        return self._generative_step(batch)

    def validation_epoch_end(self, outputs, prefix="val") -> Dict:
        """Average per-batch results, record them via ``save_metrics`` and return them.

        Args:
            outputs: List of dicts produced by ``_generative_step``.
            prefix: Split name used in the metric keys and as the key under which
                the metrics are recorded (``"val"`` or ``"test"``).

        Returns:
            Dict with ``log`` (all averaged metrics), ``preds`` (flattened
            predictions), ``{prefix}_loss`` and ``{prefix}_{val_metric}``.
        """
        self.step_count += 1
        losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
        loss = losses["loss"]
        # Generation metrics (ROUGE here, BLEU in TranslationModule) plus timing/length stats.
        gen_metrics = {
            k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "summ_len"]
        }
        metric_tensor: torch.FloatTensor = torch.tensor(gen_metrics[self.val_metric]).type_as(loss)
        # Convert loss tensors to plain numbers, then merge so that ``losses`` holds
        # every averaged value (loss keys first, generation metrics after).
        gen_metrics.update({k: v.item() for k, v in losses.items()})
        losses.update(gen_metrics)
        metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
        metrics["step_count"] = self.step_count
        self.save_metrics(metrics, prefix)  # writes to self.metrics_save_path
        preds = flatten_list([x["preds"] for x in outputs])
        return {"log": metrics, "preds": preds, f"{prefix}_loss": loss, f"{prefix}_{self.val_metric}": metric_tensor}

    def save_metrics(self, latest_metrics, type_path) -> None:
        """Append ``latest_metrics`` to the history for ``type_path`` and rewrite the metrics file."""
        self.metrics[type_path].append(latest_metrics)
        save_json(self.metrics, self.metrics_save_path)

    def calc_generative_metrics(self, preds, target) -> Dict:
        """Score decoded predictions against decoded targets with ``calculate_rouge``."""
        return calculate_rouge(preds, target)

    def _generative_step(self, batch: dict) -> dict:
        """Generate for one batch and collect loss, text, timing and quality metrics.

        Returns:
            Dict with one entry per loss name plus ``gen_time`` (seconds per
            example), ``summ_len`` (mean generated length), ``preds``, ``target``
            and whatever ``calc_generative_metrics`` returns.
        """
        pad_token_id = self.tokenizer.pad_token_id
        source_ids, source_mask, y = Seq2SeqDataset.trim_seq2seq_batch(batch, pad_token_id)
        t0 = time.time()
        generated_ids = self.model.generate(
            input_ids=source_ids,
            attention_mask=source_mask,
            use_cache=True,
            decoder_start_token_id=self.decoder_start_token_id,
        )
        # Wall-clock generation time divided by the number of examples in the batch.
        gen_time = (time.time() - t0) / source_ids.shape[0]
        preds = self.ids_to_clean_text(generated_ids)
        target = self.ids_to_clean_text(y)
        # The loss is computed on the untrimmed batch.
        loss_tensors = self._step(batch)
        base_metrics = dict(zip(self.loss_names, loss_tensors))
        gen_scores: Dict = self.calc_generative_metrics(preds, target)
        summ_len = np.mean(lmap(len, generated_ids))
        base_metrics.update(gen_time=gen_time, summ_len=summ_len, preds=preds, target=target, **gen_scores)
        return base_metrics

    def test_step(self, batch, batch_idx):
        """Generate for one test batch and return its losses and metrics."""
        return self._generative_step(batch)

    def test_epoch_end(self, outputs):
        """Aggregate test outputs exactly like validation, under the ``test`` prefix."""
        return self.validation_epoch_end(outputs, prefix="test")

    def get_dataset(self, type_path) -> Seq2SeqDataset:
        """Instantiate ``self.dataset_class`` for the split named ``type_path``."""
        n_obs = self.n_obs[type_path]
        max_target_length = self.target_lens[type_path]
        dataset = self.dataset_class(
            self.tokenizer,
            type_path=type_path,
            n_obs=n_obs,
            max_target_length=max_target_length,
            **self.dataset_kwargs,
        )
        return dataset

    def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
        """Build a ``DataLoader`` for one split.

        Args:
            type_path: Split name (``"train"``, ``"val"`` or ``"test"``).
            batch_size: Number of examples per batch.
            shuffle: Whether to shuffle; ignored when the sortish sampler is used.

        Raises:
            AssertionError: If ``--sortish_sampler`` is requested with more than one GPU.
        """
        dataset = self.get_dataset(type_path)
        sampler = None
        if self.hparams.sortish_sampler and type_path == "train":
            # TODO: assert earlier
            assert self.hparams.gpus <= 1, "sortish_sampler requires gpus <= 1"
            sampler = dataset.make_sortish_sampler(batch_size)
            # The sampler defines the order, so DataLoader-level shuffling is turned off.
            shuffle = False

        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            collate_fn=dataset.collate_fn,
            shuffle=shuffle,
            num_workers=self.num_workers,
            sampler=sampler,
        )
        return dataloader

    def train_dataloader(self) -> DataLoader:
        """Return the training loader and, as a side effect, set ``self.lr_scheduler``.

        The total step count is derived from the dataset size, batch size, GPU
        count, gradient accumulation and ``max_epochs``. A warning is issued when
        the freshly built schedule reports no positive learning rate.
        """
        dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
        examples_per_step = self.hparams.train_batch_size * max(1, self.hparams.gpus)
        t_total = (
            (len(dataloader.dataset) // examples_per_step)
            // self.hparams.accumulate_grad_batches
            * float(self.hparams.max_epochs)
        )
        scheduler = get_linear_schedule_with_warmup(
            self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
        )
        # The original condition was inverted (it warned when a rate was positive).
        # A schedule with warmup also starts at 0, so this can fire at step 0 in
        # that case; it is most useful for catching ``t_total == 0``.
        if max(scheduler.get_last_lr()) <= 0:
            warnings.warn("All learning rates are 0")
        self.lr_scheduler = scheduler
        return dataloader

    def val_dataloader(self) -> DataLoader:
        """Return the validation loader using ``hparams.eval_batch_size``."""
        return self.get_dataloader("val", batch_size=self.hparams.eval_batch_size)

    def test_dataloader(self) -> DataLoader:
        """Return the test loader using ``hparams.eval_batch_size``."""
        return self.get_dataloader("test", batch_size=self.hparams.eval_batch_size)

    @staticmethod
    def add_model_specific_args(parser, root_dir):
        """Register the base, generic and seq2seq-specific CLI options on ``parser``.

        Args:
            parser: ``argparse.ArgumentParser`` to extend in place.
            root_dir: Passed through to the base/generic argument helpers.

        Returns:
            The same ``parser``.
        """
        BaseTransformer.add_model_specific_args(parser, root_dir)
        add_generic_args(parser, root_dir)
        parser.add_argument(
            "--max_source_length",
            default=1024,
            type=int,
            help="The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--max_target_length",
            default=56,
            type=int,
            help="The maximum total target sequence length for training after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--val_max_target_length",
            default=142,  # these defaults are optimized for CNNDM. For xsum, see README.md.
            type=int,
            help="The maximum total target sequence length for validation after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--test_max_target_length",
            default=142,
            type=int,
            help="The maximum total target sequence length for test after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--data_dir",
            type=str,
            required=True,
            help="The input data dir. Should contain train.source, train.target, val.source, val.target, test.source, test.target",
        )
        parser.add_argument("--freeze_encoder", action="store_true")
        parser.add_argument("--freeze_embeds", action="store_true")
        parser.add_argument("--sortish_sampler", action="store_true", default=False)
        parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
        parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
        parser.add_argument("--n_val", type=int, default=500, required=False, help="# examples. -1 means use all.")
        parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.")
        parser.add_argument(
            "--task",
            type=str,
            default="summarization",
            required=False,
            help="Task to fine-tune for: 'summarization', or any other value for translation.",
        )
        parser.add_argument("--label_smoothing", type=float, default=0.0, required=False)
        parser.add_argument("--src_lang", type=str, default="", required=False)
        parser.add_argument("--tgt_lang", type=str, default="", required=False)
        parser.add_argument(
            "--early_stopping_patience",
            type=int,
            default=-1,
            required=False,
            help="-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs. So val_check_interval will affect it.",
        )
        return parser
class TranslationModule(SummarizationModule):
    """Translation variant of the seq2seq module; generations are scored with BLEU."""

    mode = "translation"
    loss_names = ["loss"]
    metric_names = ["bleu"]
    val_metric = "bleu"

    def __init__(self, hparams, **kwargs):
        """Extend the parent setup with language codes and MBart-specific wiring.

        With an ``MBartTokenizer`` the dataset class becomes ``MBartDataset``, and
        if the model config has no ``decoder_start_token_id`` it is set to the id
        of ``hparams.tgt_lang``.
        """
        super().__init__(hparams, **kwargs)
        self.dataset_kwargs.update(src_lang=hparams.src_lang, tgt_lang=hparams.tgt_lang)

        uses_mbart = isinstance(self.tokenizer, MBartTokenizer)
        if uses_mbart and self.model.config.decoder_start_token_id is None:
            # Decoding starts from the target-language code token.
            start_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
            self.decoder_start_token_id = start_id
            self.model.config.decoder_start_token_id = start_id
        if uses_mbart:
            self.dataset_class = MBartDataset

    def calc_generative_metrics(self, preds, target) -> dict:
        """Score decoded predictions against decoded targets with ``calculate_bleu_score``."""
        bleu_scores = calculate_bleu_score(preds, target)
        return bleu_scores
def main(args, model=None) -> SummarizationModule:
    """Train (and optionally test) a seq2seq module described by ``args``.

    Args:
        args: Parsed CLI namespace (see ``SummarizationModule.add_model_specific_args``).
        model: Optional pre-built module; when ``None`` one is created from
            ``args.task`` ("summarization" selects ``SummarizationModule``, anything
            else ``TranslationModule``).

    Returns:
        The module that was trained and/or tested.

    Raises:
        ValueError: If training is requested into an output directory that already
            holds more than 3 entries, or if ``args.logger_name`` is not recognised.
    """
    # parents=True so that a nested output path does not fail on a missing parent.
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    # NOTE(review): up to 3 pre-existing entries are tolerated despite the message
    # saying "not empty" — confirm what those entries are expected to be.
    if len(os.listdir(args.output_dir)) > 3 and args.do_train:
        raise ValueError(f"Output directory ({args.output_dir}) already exists and is not empty.")
    if model is None:
        if args.task == "summarization":
            model = SummarizationModule(args)
        else:
            model = TranslationModule(args)

    dataset = Path(args.data_dir).name
    # Named ``pl_logger`` so it does not shadow the module-level ``logger``.
    if (
        args.logger_name == "default"
        or args.fast_dev_run
        or str(args.output_dir).startswith(("/tmp", "/var"))
    ):
        pl_logger = True  # don't pollute wandb logs unnecessarily
    elif args.logger_name in ("wandb", "wandb_shared"):
        from pytorch_lightning.loggers import WandbLogger

        if args.logger_name == "wandb":
            project = os.environ.get("WANDB_PROJECT", dataset)
        else:
            project = f"hf_{dataset}"
        pl_logger = WandbLogger(name=model.output_dir.name, project=project)
    else:
        # Previously this fell through and crashed later with an unbound local.
        raise ValueError(f"Unknown logger_name: {args.logger_name!r}")

    if args.early_stopping_patience >= 0:
        es_callback = get_early_stopping_callback(model.val_metric, args.early_stopping_patience)
    else:
        es_callback = False
    trainer: pl.Trainer = generic_train(
        model,
        args,
        logging_callback=Seq2SeqLoggingCallback(),
        checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric),
        early_stopping_callback=es_callback,
        logger=pl_logger,
        # TODO: early stopping callback seems messed up
    )
    pickle_save(model.hparams, model.output_dir / "hparams.pkl")
    if not args.do_predict:
        return model

    model.hparams.test_checkpoint = ""
    # NOTE: ``recursive=True`` has no effect without ``**`` in the pattern, so only
    # checkpoints directly inside ``output_dir`` are matched (original behaviour kept).
    checkpoints = sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True))
    if checkpoints:
        # Lexicographically last checkpoint file name.
        model.hparams.test_checkpoint = checkpoints[-1]
        trainer.resume_from_checkpoint = checkpoints[-1]
    trainer.logger.log_hyperparams(model.hparams)

    # test() without a model tests using the best checkpoint automatically
    trainer.test()
    return model
if __name__ == "__main__":
    # Assemble the CLI from the Lightning Trainer options plus the module's own
    # options, then hand the parsed namespace to main().
    cli_parser = pl.Trainer.add_argparse_args(argparse.ArgumentParser())
    cli_parser = SummarizationModule.add_model_specific_args(cli_parser, os.getcwd())
    main(cli_parser.parse_args())