From dad5e12e54bc2cf80a24b3430b5c847fc213a73e Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Sat, 18 Jul 2020 07:43:57 -0400 Subject: [PATCH] [seq2seq] distillation.py accepts trainer arguments (#5865) --- examples/seq2seq/README.md | 2 +- examples/seq2seq/distillation.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/seq2seq/README.md b/examples/seq2seq/README.md index f726c63d3b5..e13ac7980ee 100644 --- a/examples/seq2seq/README.md +++ b/examples/seq2seq/README.md @@ -149,7 +149,7 @@ If 'translation' is in your task name, the computed metric will be BLEU. Otherwi For t5, you need to specify --task translation_{src}_to_{tgt} as follows: ```bash export DATA_DIR=wmt_en_ro -python run_eval.py t5_base \ +python run_eval.py t5-base \ $DATA_DIR/val.source t5_val_generations.txt \ --reference_path $DATA_DIR/val.target \ --score_path enro_bleu.json \ diff --git a/examples/seq2seq/distillation.py b/examples/seq2seq/distillation.py index 30e2e1c7556..a683fd7e057 100644 --- a/examples/seq2seq/distillation.py +++ b/examples/seq2seq/distillation.py @@ -446,6 +446,7 @@ def distill_main(args): if __name__ == "__main__": parser = argparse.ArgumentParser() + parser = pl.Trainer.add_argparse_args(parser) parser = BartSummarizationDistiller.add_model_specific_args(parser, os.getcwd()) args = parser.parse_args()