Add --text_column to run_summarization_no_trainer (#11673)

This commit is contained in:
Jonathan Chang 2021-05-11 19:58:38 +08:00 committed by GitHub
parent 024cd19bb7
commit 64232bc0df
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -184,6 +184,12 @@ def parse_args():
default=None,
help="Pretrained tokenizer name or path if not the same as model_name",
)
parser.add_argument(
"--text_column",
type=str,
default=None,
help="The name of the column in the datasets containing the full texts (for summarization).",
)
parser.add_argument(
"--summary_column",
type=str,
@ -371,9 +377,14 @@ def main():
# Get the column names for input/target.
dataset_columns = summarization_name_mapping.get(args.dataset_name, None)
text_column_name = dataset_columns[0] if dataset_columns is not None else column_names[0]
padding = "max_length" if args.pad_to_max_length else False
if args.text_column is None:
text_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
else:
text_column = args.text_column
if text_column not in column_names:
raise ValueError(
f"--text_column' value '{args.text_column}' needs to be one of: {', '.join(column_names)}"
)
if args.summary_column is None:
summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
else:
@ -388,7 +399,7 @@ def main():
padding = "max_length" if args.pad_to_max_length else False
def preprocess_function(examples):
inputs = examples[text_column_name]
inputs = examples[text_column]
targets = examples[summary_column]
inputs = [prefix + inp for inp in inputs]
model_inputs = tokenizer(inputs, max_length=args.max_source_length, padding=padding, truncation=True)