From 64232bc0df7e28f91bdad2b29fca1808089e3dfd Mon Sep 17 00:00:00 2001 From: Jonathan Chang <31893406+cccntu@users.noreply.github.com> Date: Tue, 11 May 2021 19:58:38 +0800 Subject: [PATCH] Add --text_column to run_summarization_no_trainer (#11673) --- .../run_summarization_no_trainer.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/examples/pytorch/summarization/run_summarization_no_trainer.py b/examples/pytorch/summarization/run_summarization_no_trainer.py index 7bd2edd6dd6..ab204907d4c 100644 --- a/examples/pytorch/summarization/run_summarization_no_trainer.py +++ b/examples/pytorch/summarization/run_summarization_no_trainer.py @@ -184,6 +184,12 @@ def parse_args(): default=None, help="Pretrained tokenizer name or path if not the same as model_name", ) + parser.add_argument( + "--text_column", + type=str, + default=None, + help="The name of the column in the datasets containing the full texts (for summarization).", + ) parser.add_argument( "--summary_column", type=str, @@ -371,9 +377,14 @@ def main(): # Get the column names for input/target. dataset_columns = summarization_name_mapping.get(args.dataset_name, None) - text_column_name = dataset_columns[0] if dataset_columns is not None else column_names[0] - - padding = "max_length" if args.pad_to_max_length else False + if args.text_column is None: + text_column = dataset_columns[0] if dataset_columns is not None else column_names[0] + else: + text_column = args.text_column + if text_column not in column_names: + raise ValueError( + f"--text_column' value '{args.text_column}' needs to be one of: {', '.join(column_names)}" + ) if args.summary_column is None: summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1] else: @@ -388,7 +399,7 @@ def main(): padding = "max_length" if args.pad_to_max_length else False def preprocess_function(examples): - inputs = examples[text_column_name] + inputs = examples[text_column] targets = examples[summary_column] inputs = [prefix + inp for inp in inputs] model_inputs = tokenizer(inputs, max_length=args.max_source_length, padding=padding, truncation=True)