import argparse
import glob
import logging
import os
import time
import warnings
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader

from lightning_base import BaseTransformer, add_generic_args, generic_train
from transformers import get_linear_schedule_with_warmup


try:
    from .utils import (
        use_task_specific_params,
        SummarizationDataset,
        lmap,
        flatten_list,
        pickle_save,
        save_git_info,
        save_json,
        freeze_params,
        calculate_rouge,
        get_git_info,
        ROUGE_KEYS,
        calculate_bleu_score,
    )
    from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback
except ImportError:
    from utils import (
        use_task_specific_params,
        SummarizationDataset,
        lmap,
        flatten_list,
        pickle_save,
        save_git_info,
        save_json,
        freeze_params,
        calculate_rouge,
        get_git_info,
        ROUGE_KEYS,
        calculate_bleu_score,
    )
    from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback

logger = logging.getLogger(__name__)


class SummarizationModule(BaseTransformer):
    mode = "summarization"
    loss_names = ["loss"]
    metric_names = ROUGE_KEYS
    val_metric = "rouge2"

    def __init__(self, hparams, **kwargs):
        super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs)
        use_task_specific_params(self.model, "summarization")
        save_git_info(self.hparams.output_dir)
        self.metrics_save_path = Path(self.output_dir) / "metrics.json"
        self.hparams_save_path = Path(self.output_dir) / "hparams.pkl"
        pickle_save(self.hparams, self.hparams_save_path)
        self.step_count = 0
        self.metrics = defaultdict(list)

        self.dataset_kwargs: dict = dict(
            data_dir=self.hparams.data_dir,
            max_source_length=self.hparams.max_source_length,
            prefix=self.model.config.prefix or "",
        )
        n_observations_per_split = {
            "train": self.hparams.n_train,
            "val": self.hparams.n_val,
            "test": self.hparams.n_test,
        }
        self.n_obs = {k: v if v >= 0 else None for k, v in n_observations_per_split.items()}

        self.target_lens = {
            "train": self.hparams.max_target_length,
            "val": self.hparams.val_max_target_length,
            "test": self.hparams.test_max_target_length,
        }
        assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
        assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"

        if self.hparams.freeze_embeds:
            self.freeze_embeds()
        if self.hparams.freeze_encoder:
            freeze_params(self.model.model.encoder)  # TODO: this will break for t5

        self.hparams.git_sha = get_git_info()["repo_sha"]
        self.num_workers = hparams.num_workers

    def freeze_embeds(self):
        """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
        try:
            freeze_params(self.model.model.shared)
            for d in [self.model.model.encoder, self.model.model.decoder]:
                freeze_params(d.embed_positions)
                freeze_params(d.embed_tokens)
        except AttributeError:
            freeze_params(self.model.shared)
            for d in [self.model.encoder, self.model.decoder]:
                freeze_params(d.embed_tokens)

    def forward(self, input_ids, **kwargs):
        return self.model(input_ids, **kwargs)

    def ids_to_clean_text(self, generated_ids: List[int]):
        gen_text = self.tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        return lmap(str.strip, gen_text)

    def _step(self, batch: dict) -> Tuple:
        pad_token_id = self.tokenizer.pad_token_id
        source_ids, source_mask, y = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]
        y_ids = y[:, :-1].contiguous()
        lm_labels = y[:, 1:].clone()
        lm_labels[y[:, 1:] == pad_token_id] = -100
        outputs = self(source_ids, attention_mask=source_mask, decoder_input_ids=y_ids, labels=lm_labels)
        loss = outputs[0]
        return (loss,)

    def training_step(self, batch, batch_idx) -> Dict:
        loss_tensors = self._step(batch)
        logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
        return {"loss": loss_tensors[0], "log": logs}

    def validation_step(self, batch, batch_idx) -> Dict:
        return self._generative_step(batch)

    def validation_epoch_end(self, outputs, prefix="val") -> Dict:
        self.step_count += 1
        losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
        loss = losses["loss"]
        rouges = {k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "summ_len"]}
        rouge_tensor: torch.FloatTensor = torch.tensor(rouges[self.val_metric]).type_as(loss)
        rouges.update({k: v.item() for k, v in losses.items()})
        losses.update(rouges)
        metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
        metrics["step_count"] = self.step_count
        self.save_metrics(metrics, prefix)  # writes to self.metrics_save_path
        preds = flatten_list([x["preds"] for x in outputs])
        return {"log": metrics, "preds": preds, f"{prefix}_loss": loss, f"{prefix}_{self.val_metric}": rouge_tensor}

    def save_metrics(self, latest_metrics, type_path) -> None:
        self.metrics[type_path].append(latest_metrics)
        save_json(self.metrics, self.metrics_save_path)

    def calc_generative_metrics(self, preds, target) -> Dict:
        return calculate_rouge(preds, target)

    def _generative_step(self, batch: dict) -> dict:
        pad_token_id = self.tokenizer.pad_token_id
        source_ids, source_mask, y = SummarizationDataset.trim_seq2seq_batch(batch, pad_token_id)
        t0 = time.time()
        generated_ids = self.model.generate(input_ids=source_ids, attention_mask=source_mask, use_cache=True)
        gen_time = (time.time() - t0) / source_ids.shape[0]
        preds = self.ids_to_clean_text(generated_ids)
        target = self.ids_to_clean_text(y)
        loss_tensors = self._step(batch)
        base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
        rouge: Dict = self.calc_generative_metrics(preds, target)
        summ_len = np.mean(lmap(len, generated_ids))
        base_metrics.update(gen_time=gen_time, summ_len=summ_len, preds=preds, target=target, **rouge)
        return base_metrics

    def test_step(self, batch, batch_idx):
        return self._generative_step(batch)

    def test_epoch_end(self, outputs):
        return self.validation_epoch_end(outputs, prefix="test")

    def get_dataset(self, type_path) -> SummarizationDataset:
        n_obs = self.n_obs[type_path]
        max_target_length = self.target_lens[type_path]
        dataset = SummarizationDataset(
            self.tokenizer,
            type_path=type_path,
            n_obs=n_obs,
            max_target_length=max_target_length,
            **self.dataset_kwargs,
        )
        return dataset

    def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
        dataset = self.get_dataset(type_path)
        sampler = None
        if self.hparams.sortish_sampler and type_path == "train":
            assert self.hparams.gpus <= 1  # TODO: assert earlier
            sampler = dataset.make_sortish_sampler(batch_size)
            shuffle = False

        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            collate_fn=dataset.collate_fn,
            shuffle=shuffle,
            num_workers=self.num_workers,
            sampler=sampler,
        )
        return dataloader

    def train_dataloader(self) -> DataLoader:
        dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
        t_total = (
            (len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.gpus)))
            // self.hparams.gradient_accumulation_steps
            * float(self.hparams.num_train_epochs)
        )
        scheduler = get_linear_schedule_with_warmup(
            self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
        )
        # Warn when the schedule would never produce a positive learning rate.
        if max(scheduler.get_last_lr()) <= 0:
            warnings.warn("All learning rates are 0")
        self.lr_scheduler = scheduler
        return dataloader

    def val_dataloader(self) -> DataLoader:
        return self.get_dataloader("val", batch_size=self.hparams.eval_batch_size)

    def test_dataloader(self) -> DataLoader:
        return self.get_dataloader("test", batch_size=self.hparams.eval_batch_size)

    @staticmethod
    def add_model_specific_args(parser, root_dir):
        BaseTransformer.add_model_specific_args(parser, root_dir)
        add_generic_args(parser, root_dir)
        parser.add_argument(
            "--max_source_length",
            default=1024,
            type=int,
            help="The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--max_target_length",
            default=56,
            type=int,
            help="The maximum total target sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--val_max_target_length",
            default=142,  # these defaults are optimized for CNNDM. For xsum, see README.md.
            type=int,
            help="The maximum total target sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--test_max_target_length",
            default=142,
            type=int,
            help="The maximum total target sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--data_dir",
            type=str,
            required=True,
            help="The input data dir. Should contain train.source, train.target, val.source, val.target, "
            "test.source, test.target",
        )
        parser.add_argument("--freeze_encoder", action="store_true")
        parser.add_argument("--freeze_embeds", action="store_true")
        parser.add_argument("--sortish_sampler", action="store_true", default=False)
        parser.add_argument("--logger", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
        parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
        parser.add_argument("--n_val", type=int, default=500, required=False, help="# examples. -1 means use all.")
        parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.")
        parser.add_argument(
            "--task",
            type=str,
            default="summarization",
            required=False,
            help="Task to fine-tune on: 'summarization' or 'translation'.",
        )
        return parser


class TranslationModule(SummarizationModule):
    mode = "translation"
    loss_names = ["loss"]
    metric_names = ["bleu"]
    val_metric = "bleu"

    def calc_generative_metrics(self, preds, target) -> dict:
        return calculate_bleu_score(preds, target)


def main(args, model=None) -> SummarizationModule:
    Path(args.output_dir).mkdir(exist_ok=True)
    if len(os.listdir(args.output_dir)) > 3 and args.do_train:
        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    if model is None:
        if args.task == "summarization":
            model: SummarizationModule = SummarizationModule(args)
        else:
            model: SummarizationModule = TranslationModule(args)

    if (
        args.logger == "default"
        or args.fast_dev_run
        or str(args.output_dir).startswith("/tmp")
        or str(args.output_dir).startswith("/var")
    ):
        logger = True  # don't pollute wandb logs unnecessarily
    elif args.logger == "wandb":
        from pytorch_lightning.loggers import WandbLogger

        logger = WandbLogger(name=model.output_dir.name)
    elif args.logger == "wandb_shared":
        from pytorch_lightning.loggers import WandbLogger

        logger = WandbLogger(name=model.output_dir.name)

    trainer: pl.Trainer = generic_train(
        model,
        args,
        logging_callback=Seq2SeqLoggingCallback(),
        checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric),
        logger=logger,
        # TODO: early stopping callback seems messed up
    )
    pickle_save(model.hparams, model.output_dir / "hparams.pkl")
    if not args.do_predict:
        return model

    model.hparams.test_checkpoint = ""
    checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True)))
    if checkpoints:
        model.hparams.test_checkpoint = checkpoints[-1]
        trainer.resume_from_checkpoint = checkpoints[-1]
    trainer.logger.log_hyperparams(model.hparams)
    trainer.test(model)  # this breaks in DDP, known lightning issue. See evaluate_checkpoint to recover metrics.
    return model


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
    args = parser.parse_args()

    main(args)
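
# Example invocation (a sketch, not an officially documented command). Only --data_dir, --task,
# --n_val, --freeze_encoder, and --freeze_embeds are defined by add_model_specific_args above;
# --output_dir, --do_train, and --do_predict are read by main() and are assumed to be registered
# by add_generic_args / BaseTransformer in lightning_base, along with model, batch-size, and
# optimizer flags not shown in this file. Paths are placeholders.
#
#   python path/to/this_script.py \
#       --data_dir ./cnn_dm \
#       --output_dir ./summarization_ckpts \
#       --task summarization \
#       --n_val 500 \
#       --freeze_encoder --freeze_embeds \
#       --do_train --do_predict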