diff --git a/examples/summarization/bart/finetune.py b/examples/summarization/bart/finetune.py
index 58cdccf443d..9e3d55b3e94 100644
--- a/examples/summarization/bart/finetune.py
+++ b/examples/summarization/bart/finetune.py
@@ -166,8 +166,12 @@ def main(args):
 
     # Optionally, predict on dev set and write to output_dir
     if args.do_predict:
+        # See https://github.com/huggingface/transformers/issues/3159
+        # pl uses this format to create a checkpoint:
+        # https://github.com/PyTorchLightning/pytorch-lightning/blob/master\
+        # /pytorch_lightning/callbacks/model_checkpoint.py#L169
         checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True)))
-        SummarizationTrainer.load_from_checkpoint(checkpoints[-1])
+        model = model.load_from_checkpoint(checkpoints[-1])
         trainer.test(model)
diff --git a/examples/summarization/bart/test_bart_examples.py b/examples/summarization/bart/test_bart_examples.py
index ae137975964..199076d84f6 100644
--- a/examples/summarization/bart/test_bart_examples.py
+++ b/examples/summarization/bart/test_bart_examples.py
@@ -94,7 +94,15 @@ class TestBartExamples(unittest.TestCase):
         )
         main(argparse.Namespace(**args_d))
         args_d.update({"do_train": False, "do_predict": True})
+        main(argparse.Namespace(**args_d))
+        contents = os.listdir(output_dir)
+        expected_contents = {
+            "checkpointepoch=0.ckpt",
+            "test_results.txt",
+        }
+        created_files = {os.path.basename(p) for p in contents}
+        self.assertSetEqual(expected_contents, created_files)
 
     def test_t5_run_sum_cli(self):
         args_d: dict = DEFAULT_ARGS.copy()
@@ -111,6 +119,7 @@ class TestBartExamples(unittest.TestCase):
             do_predict=True,
         )
         main(argparse.Namespace(**args_d))
+        # args_d.update({"do_train": False, "do_predict": True})
         # main(argparse.Namespace(**args_d))
diff --git a/examples/transformer_base.py b/examples/transformer_base.py
index a3b81610ea6..0ab355ad657 100644
--- a/examples/transformer_base.py
+++ b/examples/transformer_base.py
@@ -1,3 +1,4 @@
+import argparse
 import logging
 import os
 import random
@@ -38,7 +39,7 @@ MODEL_MODES = {
 }
 
 
-def set_seed(args):
+def set_seed(args: argparse.Namespace):
     random.seed(args.seed)
     np.random.seed(args.seed)
     torch.manual_seed(args.seed)
@@ -47,7 +48,7 @@
 
 
 class BaseTransformer(pl.LightningModule):
-    def __init__(self, hparams, num_labels=None, mode="base", **config_kwargs):
+    def __init__(self, hparams: argparse.Namespace, num_labels=None, mode="base", **config_kwargs):
         "Initialize a model."
 
         super(BaseTransformer, self).__init__()
@@ -192,7 +193,7 @@ class BaseTransformer(pl.LightningModule):
 
 
 class LoggingCallback(pl.Callback):
-    def on_validation_end(self, trainer, pl_module):
+    def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
         logger.info("***** Validation results *****")
         if pl_module.is_logger():
             metrics = trainer.callback_metrics
@@ -201,7 +202,7 @@
                 if key not in ["log", "progress_bar"]:
                     logger.info("{} = {}\n".format(key, str(metrics[key])))
 
-    def on_test_end(self, trainer, pl_module):
+    def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
         logger.info("***** Test results *****")
 
         if pl_module.is_logger():
@@ -256,7 +257,7 @@ def add_generic_args(parser, root_dir):
     parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
 
 
-def generic_train(model, args):
+def generic_train(model: BaseTransformer, args: argparse.Namespace):
     # init model
     set_seed(args)
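
Not part of the patch, only a minimal sketch of the checkpoint-loading pattern that the finetune.py change relies on, assuming a trained pl.LightningModule instance and the same output_dir that training wrote to; load_latest_checkpoint is a hypothetical helper name for illustration, not something defined in the repository.

import glob
import os


def load_latest_checkpoint(model, output_dir):
    # Hypothetical helper (illustration only). PyTorch Lightning's ModelCheckpoint
    # callback saves files named "checkpointepoch=<n>.ckpt"; a lexicographic sort
    # followed by taking the last entry picks the newest epoch, which is adequate
    # for the single-epoch runs exercised by the tests above.
    checkpoints = sorted(glob.glob(os.path.join(output_dir, "checkpointepoch=*.ckpt")))
    if not checkpoints:
        raise FileNotFoundError("no checkpoints found in " + output_dir)
    # load_from_checkpoint is a classmethod, so calling it on the instance (as the
    # patch does) returns a fresh model restored from the saved weights and hparams.
    return model.load_from_checkpoint(checkpoints[-1])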