diff --git a/examples/seq2seq/callbacks.py b/examples/seq2seq/callbacks.py
index 5942001b3d8..45785e681cf 100644
--- a/examples/seq2seq/callbacks.py
+++ b/examples/seq2seq/callbacks.py
@@ -75,21 +75,23 @@ class Seq2SeqLoggingCallback(pl.Callback):
         return self._write_logs(trainer, pl_module, "test")
 
 
-def get_checkpoint_callback(output_dir, metric, save_top_k=1):
+def get_checkpoint_callback(output_dir, metric, save_top_k=1, lower_is_better=False):
     """Saves the best model by validation ROUGE2 score."""
     if metric == "rouge2":
         exp = "{val_avg_rouge2:.4f}-{step_count}"
     elif metric == "bleu":
         exp = "{val_avg_bleu:.4f}-{step_count}"
+    elif metric == "loss":
+        exp = "{val_avg_loss:.4f}-{step_count}"
     else:
         raise NotImplementedError(
-            f"seq2seq callbacks only support rouge2 and bleu, got {metric}, You can make your own by adding to this function."
+            f"seq2seq callbacks only support rouge2, bleu and loss, got {metric}, You can make your own by adding to this function."
         )
 
     checkpoint_callback = ModelCheckpoint(
         filepath=os.path.join(output_dir, exp),
         monitor=f"val_{metric}",
-        mode="max",
+        mode="min" if "loss" in metric else "max",
         save_top_k=save_top_k,
         period=0,  # maybe save a checkpoint every time val is run, not just end of epoch.
     )
@@ -98,8 +100,8 @@ def get_checkpoint_callback(output_dir, metric, save_top_k=1):
 
 def get_early_stopping_callback(metric, patience):
     return EarlyStopping(
-        monitor=f"val_{metric}",
-        mode="max",
+        monitor=f"val_{metric}",  # does this need avg?
+        mode="min" if "loss" in metric else "max",
         patience=patience,
         verbose=True,
     )
diff --git a/examples/seq2seq/finetune.py b/examples/seq2seq/finetune.py
index 8181b8e8602..73b69d02b34 100644
--- a/examples/seq2seq/finetune.py
+++ b/examples/seq2seq/finetune.py
@@ -148,10 +148,10 @@ class SummarizationModule(BaseTransformer):
         lm_logits = outputs[0]
         if self.hparams.label_smoothing == 0:
             # Same behavior as modeling_bart.py, besides ignoring pad_token_id
-            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
+            ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
 
             assert lm_logits.shape[-1] == self.model.config.vocab_size
-            loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))
+            loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))
         else:
             lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
             loss, nll_loss = label_smoothed_nll_loss(
@@ -178,15 +178,25 @@ class SummarizationModule(BaseTransformer):
         self.step_count += 1
         losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
         loss = losses["loss"]
-        rouges = {k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "gen_len"]}
-        rouge_tensor: torch.FloatTensor = torch.tensor(rouges[self.val_metric]).type_as(loss)
-        rouges.update({k: v.item() for k, v in losses.items()})
-        losses.update(rouges)
-        metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
-        metrics["step_count"] = self.step_count
-        self.save_metrics(metrics, prefix)  # writes to self.metrics_save_path
+        generative_metrics = {
+            k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "gen_len"]
+        }
+        metric_val = (
+            generative_metrics[self.val_metric] if self.val_metric in generative_metrics else losses[self.val_metric]
+        )
+        metric_tensor: torch.FloatTensor = torch.tensor(metric_val).type_as(loss)
+        generative_metrics.update({k: v.item() for k, v in losses.items()})
+        losses.update(generative_metrics)
+        all_metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
+        all_metrics["step_count"] = self.step_count
+        self.save_metrics(all_metrics, prefix)  # writes to self.metrics_save_path
         preds = flatten_list([x["preds"] for x in outputs])
-        return {"log": metrics, "preds": preds, f"{prefix}_loss": loss, f"{prefix}_{self.val_metric}": rouge_tensor}
+        return {
+            "log": all_metrics,
+            "preds": preds,
+            f"{prefix}_loss": loss,
+            f"{prefix}_{self.val_metric}": metric_tensor,
+        }
 
     def save_metrics(self, latest_metrics, type_path) -> None:
         self.metrics[type_path].append(latest_metrics)
@@ -306,7 +316,9 @@ class SummarizationModule(BaseTransformer):
         parser.add_argument("--src_lang", type=str, default="", required=False)
         parser.add_argument("--tgt_lang", type=str, default="", required=False)
         parser.add_argument("--eval_beams", type=int, default=None, required=False)
-        parser.add_argument("--val_metric", type=str, default=None, required=False)
+        parser.add_argument(
+            "--val_metric", type=str, default=None, required=False, choices=["bleu", "rouge2", "loss", None]
+        )
        parser.add_argument("--save_top_k", type=int, default=1, required=False, help="How many checkpoints to save")
         parser.add_argument(
             "--early_stopping_patience",
@@ -366,14 +378,17 @@ def main(args, model=None) -> SummarizationModule:
         es_callback = get_early_stopping_callback(model.val_metric, args.early_stopping_patience)
     else:
         es_callback = False
+
+    lower_is_better = args.val_metric == "loss"
     trainer: pl.Trainer = generic_train(
         model,
         args,
         logging_callback=Seq2SeqLoggingCallback(),
-        checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric, args.save_top_k),
+        checkpoint_callback=get_checkpoint_callback(
+            args.output_dir, model.val_metric, args.save_top_k, lower_is_better
+        ),
         early_stopping_callback=es_callback,
         logger=logger,
-        # TODO: early stopping callback seems messed up
     )
     pickle_save(model.hparams, model.output_dir / "hparams.pkl")
     if not args.do_predict:
diff --git a/examples/seq2seq/test_seq2seq_examples.py b/examples/seq2seq/test_seq2seq_examples.py
index 2ecc7b88837..92d74eeefa2 100644
--- a/examples/seq2seq/test_seq2seq_examples.py
+++ b/examples/seq2seq/test_seq2seq_examples.py
@@ -33,7 +33,7 @@ CUDA_AVAILABLE = torch.cuda.is_available()
 CHEAP_ARGS = {
     "label_smoothing": 0.2,
     "eval_beams": 1,
-    "val_metric": None,
+    "val_metric": "loss",
     "save_top_k": 1,
     "adafactor": True,
     "early_stopping_patience": 2,
@@ -262,9 +262,9 @@ class TestSummarizationDistiller(unittest.TestCase):
         if not check_contents:
             return model
         contents = os.listdir(output_dir)
-        ckpt_name = "val_avg_rouge2=0.0000-step_count=2.ckpt"  # "val_avg_rouge2=0.0000-epoch=1.ckpt"  # "epoch=1-val_avg_rouge2=0.0000.ckpt"
         contents = {os.path.basename(p) for p in contents}
-        self.assertIn(ckpt_name, contents)
+        ckpt_files = [p for p in contents if p.endswith("ckpt")]
+        assert len(ckpt_files) > 0
         self.assertIn("test_generations.txt", contents)
         self.assertIn("test_results.txt", contents)
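
For reference, a minimal usage sketch (not part of the patch) of how the updated callbacks can be wired up for loss-based model selection, mirroring what finetune.main() now does. The import path, output_dir value, and patience value below are illustrative assumptions.

# Usage sketch only; assumes it is run from examples/seq2seq so that callbacks.py is importable.
from callbacks import get_checkpoint_callback, get_early_stopping_callback

metric = "loss"  # "rouge2" and "bleu" remain supported and keep mode="max"
lower_is_better = metric == "loss"  # mirrors the new assignment in finetune.main()

# For metric="loss" both callbacks monitor "val_loss" with mode="min";
# checkpoints are named like "val_avg_loss=<value>-step_count=<n>.ckpt".
checkpoint_callback = get_checkpoint_callback(
    output_dir="output", metric=metric, save_top_k=1, lower_is_better=lower_is_better
)
early_stopping_callback = get_early_stopping_callback(metric=metric, patience=2)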