mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Add --text_column to run_summarization_no_trainer (#11673)
This commit is contained in:
parent
024cd19bb7
commit
64232bc0df
@ -184,6 +184,12 @@ def parse_args():
|
||||
default=None,
|
||||
help="Pretrained tokenizer name or path if not the same as model_name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--text_column",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The name of the column in the datasets containing the full texts (for summarization).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--summary_column",
|
||||
type=str,
|
||||
@ -371,9 +377,14 @@ def main():
|
||||
|
||||
# Get the column names for input/target.
|
||||
dataset_columns = summarization_name_mapping.get(args.dataset_name, None)
|
||||
text_column_name = dataset_columns[0] if dataset_columns is not None else column_names[0]
|
||||
|
||||
padding = "max_length" if args.pad_to_max_length else False
|
||||
if args.text_column is None:
|
||||
text_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
|
||||
else:
|
||||
text_column = args.text_column
|
||||
if text_column not in column_names:
|
||||
raise ValueError(
|
||||
f"--text_column' value '{args.text_column}' needs to be one of: {', '.join(column_names)}"
|
||||
)
|
||||
if args.summary_column is None:
|
||||
summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
|
||||
else:
|
||||
@ -388,7 +399,7 @@ def main():
|
||||
padding = "max_length" if args.pad_to_max_length else False
|
||||
|
||||
def preprocess_function(examples):
|
||||
inputs = examples[text_column_name]
|
||||
inputs = examples[text_column]
|
||||
targets = examples[summary_column]
|
||||
inputs = [prefix + inp for inp in inputs]
|
||||
model_inputs = tokenizer(inputs, max_length=args.max_source_length, padding=padding, truncation=True)
|
||||
|
Loading…
Reference in New Issue
Block a user