Fix quality

This commit is contained in:
Sylvain Gugger 2021-06-10 09:27:11 -04:00
parent 73a532651a
commit d72e5a3a6d

View File

@ -76,10 +76,16 @@ def parse_args():
"--validation_file", type=str, default=None, help="A csv or a json file containing the validation data."
)
parser.add_argument(
"--text_column_name", type=str, default=None, help="The column name of text to input in the file (a csv or JSON file)."
"--text_column_name",
type=str,
default=None,
help="The column name of text to input in the file (a csv or JSON file).",
)
parser.add_argument(
"--label_column_name", type=str, default=None, help="The column name of label to input in the file (a csv or JSON file)."
"--label_column_name",
type=str,
default=None,
help="The column name of label to input in the file (a csv or JSON file).",
)
parser.add_argument(
"--max_length",
@ -266,17 +272,17 @@ def main():
column_names = raw_datasets["validation"].column_names
features = raw_datasets["validation"].features
if data_args.text_column_name is not None:
text_column_name = data_args.text_column_name
if args.text_column_name is not None:
text_column_name = args.text_column_name
elif "tokens" in column_names:
text_column_name = "tokens"
else:
text_column_name = column_names[0]
if data_args.label_column_name is not None:
label_column_name = data_args.label_column_name
elif f"{data_args.task_name}_tags" in column_names:
label_column_name = f"{data_args.task_name}_tags"
if args.label_column_name is not None:
label_column_name = args.label_column_name
elif f"{args.task_name}_tags" in column_names:
label_column_name = f"{args.task_name}_tags"
else:
label_column_name = column_names[1]