diff --git a/examples/pytorch/token-classification/run_ner_no_trainer.py b/examples/pytorch/token-classification/run_ner_no_trainer.py index 07b2f9e2d45..c6f86cca471 100755 --- a/examples/pytorch/token-classification/run_ner_no_trainer.py +++ b/examples/pytorch/token-classification/run_ner_no_trainer.py @@ -76,10 +76,16 @@ def parse_args(): "--validation_file", type=str, default=None, help="A csv or a json file containing the validation data." ) parser.add_argument( - "--text_column_name", type=str, default=None, help="The column name of text to input in the file (a csv or JSON file)." + "--text_column_name", + type=str, + default=None, + help="The column name of text to input in the file (a csv or JSON file).", ) parser.add_argument( - "--label_column_name", type=str, default=None, help="The column name of label to input in the file (a csv or JSON file)." + "--label_column_name", + type=str, + default=None, + help="The column name of label to input in the file (a csv or JSON file).", ) parser.add_argument( "--max_length", @@ -266,17 +272,17 @@ def main(): column_names = raw_datasets["validation"].column_names features = raw_datasets["validation"].features - if data_args.text_column_name is not None: - text_column_name = data_args.text_column_name + if args.text_column_name is not None: + text_column_name = args.text_column_name elif "tokens" in column_names: text_column_name = "tokens" else: text_column_name = column_names[0] - if data_args.label_column_name is not None: - label_column_name = data_args.label_column_name - elif f"{data_args.task_name}_tags" in column_names: - label_column_name = f"{data_args.task_name}_tags" + if args.label_column_name is not None: + label_column_name = args.label_column_name + elif f"{args.task_name}_tags" in column_names: + label_column_name = f"{args.task_name}_tags" else: label_column_name = column_names[1]