mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix quality
This commit is contained in:
parent
73a532651a
commit
d72e5a3a6d
@ -76,10 +76,16 @@ def parse_args():
|
||||
"--validation_file", type=str, default=None, help="A csv or a json file containing the validation data."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--text_column_name", type=str, default=None, help="The column name of text to input in the file (a csv or JSON file)."
|
||||
"--text_column_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The column name of text to input in the file (a csv or JSON file).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--label_column_name", type=str, default=None, help="The column name of label to input in the file (a csv or JSON file)."
|
||||
"--label_column_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The column name of label to input in the file (a csv or JSON file).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_length",
|
||||
@ -266,17 +272,17 @@ def main():
|
||||
column_names = raw_datasets["validation"].column_names
|
||||
features = raw_datasets["validation"].features
|
||||
|
||||
if data_args.text_column_name is not None:
|
||||
text_column_name = data_args.text_column_name
|
||||
if args.text_column_name is not None:
|
||||
text_column_name = args.text_column_name
|
||||
elif "tokens" in column_names:
|
||||
text_column_name = "tokens"
|
||||
else:
|
||||
text_column_name = column_names[0]
|
||||
|
||||
if data_args.label_column_name is not None:
|
||||
label_column_name = data_args.label_column_name
|
||||
elif f"{data_args.task_name}_tags" in column_names:
|
||||
label_column_name = f"{data_args.task_name}_tags"
|
||||
if args.label_column_name is not None:
|
||||
label_column_name = args.label_column_name
|
||||
elif f"{args.task_name}_tags" in column_names:
|
||||
label_column_name = f"{args.task_name}_tags"
|
||||
else:
|
||||
label_column_name = column_names[1]
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user