[examples] fix summarization do_predict (#3866)

Repository: https://github.com/huggingface/transformers.git
Commit: a504cb49ec
Parent: 52c85f847a
@@ -166,8 +166,12 @@ def main(args):
 
     # Optionally, predict on dev set and write to output_dir
     if args.do_predict:
+        # See https://github.com/huggingface/transformers/issues/3159
+        # pl use this format to create a checkpoint:
+        # https://github.com/PyTorchLightning/pytorch-lightning/blob/master\
+        # /pytorch_lightning/callbacks/model_checkpoint.py#L169
         checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True)))
-        SummarizationTrainer.load_from_checkpoint(checkpoints[-1])
+        model = model.load_from_checkpoint(checkpoints[-1])
         trainer.test(model)
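The heart of the fix is the replaced line: in PyTorch Lightning, load_from_checkpoint builds and returns a new model instance rather than loading weights into the object it is called on, so the old SummarizationTrainer.load_from_checkpoint(checkpoints[-1]) threw its result away and trainer.test(model) ran on the un-restored model. A minimal sketch of the corrected pattern (the helper name and standalone framing are illustrative, not part of the commit):

    import glob
    import os

    def latest_checkpoint_model(model, output_dir: str):
        # Lightning of this era wrote checkpoints as "checkpointepoch=<N>.ckpt";
        # a lexically sorted glob puts the newest epoch last (fine while the
        # epoch count stays single-digit, as in the example's tests).
        checkpoints = sorted(glob.glob(os.path.join(output_dir, "checkpointepoch=*.ckpt")))
        # load_from_checkpoint returns a fresh instance; it does not mutate
        # `model`, so the result must be bound before calling trainer.test().
        return model.load_from_checkpoint(checkpoints[-1])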
@@ -94,7 +94,15 @@ class TestBartExamples(unittest.TestCase):
         )
         main(argparse.Namespace(**args_d))
         args_d.update({"do_train": False, "do_predict": True})
+
         main(argparse.Namespace(**args_d))
+        contents = os.listdir(output_dir)
+        expected_contents = {
+            "checkpointepoch=0.ckpt",
+            "test_results.txt",
+        }
+        created_files = {os.path.basename(p) for p in contents}
+        self.assertSetEqual(expected_contents, created_files)

     def test_t5_run_sum_cli(self):
         args_d: dict = DEFAULT_ARGS.copy()
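The new assertions give do_predict a regression test: train once, flip to predict-only mode, rerun, and compare the artifacts in output_dir against an exact expected set. A self-contained sketch of that assertion pattern (the file names come from the diff; the temp-dir scaffolding is illustrative):

    import os
    import tempfile
    import unittest

    class TestOutputArtifacts(unittest.TestCase):
        def test_expected_files_created(self):
            with tempfile.TemporaryDirectory() as output_dir:
                # Stand-ins for what the example writes: a Lightning checkpoint
                # from training and the test_results.txt produced by do_predict.
                for name in ("checkpointepoch=0.ckpt", "test_results.txt"):
                    open(os.path.join(output_dir, name), "w").close()
                created_files = {os.path.basename(p) for p in os.listdir(output_dir)}
                # assertSetEqual reports missing and unexpected elements on
                # failure, which pinpoints exactly which artifact went wrong.
                self.assertSetEqual({"checkpointepoch=0.ckpt", "test_results.txt"}, created_files)

    if __name__ == "__main__":
        unittest.main()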
@@ -111,6 +119,7 @@ class TestBartExamples(unittest.TestCase):
             do_predict=True,
         )
         main(argparse.Namespace(**args_d))

         # args_d.update({"do_train": False, "do_predict": True})
         # main(argparse.Namespace(**args_d))
@@ -1,3 +1,4 @@
+import argparse
 import logging
 import os
 import random
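The new import exists only to support the argparse.Namespace annotations added in the hunks below.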
@@ -38,7 +39,7 @@ MODEL_MODES = {
 }


-def set_seed(args):
+def set_seed(args: argparse.Namespace):
     random.seed(args.seed)
     np.random.seed(args.seed)
     torch.manual_seed(args.seed)
@@ -47,7 +48,7 @@ def set_seed(args):


 class BaseTransformer(pl.LightningModule):
-    def __init__(self, hparams, num_labels=None, mode="base", **config_kwargs):
+    def __init__(self, hparams: argparse.Namespace, num_labels=None, mode="base", **config_kwargs):
         "Initialize a model."

         super(BaseTransformer, self).__init__()
@@ -192,7 +193,7 @@ class BaseTransformer(pl.LightningModule):


 class LoggingCallback(pl.Callback):
-    def on_validation_end(self, trainer, pl_module):
+    def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
         logger.info("***** Validation results *****")
         if pl_module.is_logger():
             metrics = trainer.callback_metrics
@@ -201,7 +202,7 @@ class LoggingCallback(pl.Callback):
                 if key not in ["log", "progress_bar"]:
                     logger.info("{} = {}\n".format(key, str(metrics[key])))

-    def on_test_end(self, trainer, pl_module):
+    def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
         logger.info("***** Test results *****")

         if pl_module.is_logger():
@@ -256,7 +257,7 @@ def add_generic_args(parser, root_dir):
     parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")


-def generic_train(model, args):
+def generic_train(model: BaseTransformer, args: argparse.Namespace):
     # init model
     set_seed(args)

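The last six hunks change no behavior: they annotate set_seed, BaseTransformer.__init__, the LoggingCallback hooks, and generic_train with the types callers already pass (argparse.Namespace, pl.Trainer, pl.LightningModule, BaseTransformer). A minimal sketch of the annotated set_seed in use (the standalone framing and the direct Namespace construction are illustrative):

    import argparse
    import random

    import numpy as np
    import torch

    def set_seed(args: argparse.Namespace):
        # Same body as the diff; the annotation documents that callers hand
        # in a parsed argparse namespace with a `seed` attribute.
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)

    set_seed(argparse.Namespace(seed=42))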