From b1dade34db9a667446c12cd0c3f93e598136b476 Mon Sep 17 00:00:00 2001 From: VictorSanh Date: Thu, 1 Nov 2018 01:05:11 -0400 Subject: [PATCH] Convert flags to argparse in `run_classifier_pytorch.py` --- run_classifier_pytorch.py | 145 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 145 insertions(+) create mode 100644 run_classifier_pytorch.py diff --git a/run_classifier_pytorch.py b/run_classifier_pytorch.py new file mode 100644 index 00000000000..7931b6b1b52 --- /dev/null +++ b/run_classifier_pytorch.py @@ -0,0 +1,145 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""BERT finetuning runner.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# import csv +# import os +# import modeling_pytorch +# import optimization +# import tokenization + +import argparse + +parser = argparse.ArgumentParser() + +## Required parameters +parser.add_argument("--data_dir", + default = None, + type = str, + required = True, + help = "The input data dir. Should contain the .tsv files (or other data files) for the task.") +parser.add_argument("--bert_config_file", + default = None, + type = str, + required = True, + help = "The config json file corresponding to the pre-trained BERT model. \n" + "This specifies the model architecture.") +parser.add_argument("--task_name", + default = None, + type = str, + required = True, + help = "The name of the task to train.") +parser.add_argument("--vocab_file", + default = None, + type = str, + required = True, + help = "The vocabulary file that the BERT model was trained on.") +parser.add_argument("--output_dir", + default = None, + type = str, + required = True, + help = "The output directory where the model checkpoints will be written.") + +## Other parameters +parser.add_argument("--init_checkpoint", + default = None, + type = str, + help = "Initial checkpoint (usually from a pre-trained BERT model).") +parser.add_argument("--do_lower_case", + default = True, + type = bool, + help = "Whether to lower case the input text. Should be True for uncased models and False for cased models.") +parser.add_argument("--max_seq_length", + default = 128, + type = int, + help = "The maximum total input sequence length after WordPiece tokenization. \n" + "Sequences longer than this will be truncated, and sequences shorter \n" + "than this will be padded.") +parser.add_argument("--do_train", + default = False, + type = bool, + help = "Whether to run training.") +parser.add_argument("--do_eval", + default = False, + type = bool, + help = "Whether to run eval on the dev set.") +parser.add_argument("--train_batch_size", + default = 32, + type = int, + help = "Total batch size for training.") +parser.add_argument("--eval_batch_size", + default = 8, + type = int, + help = "Total batch size for eval.") +parser.add_argument("--learning_rate", + default = 5e-5, + type = float, + help = "The initial learning rate for Adam.") +parser.add_argument("--num_train_epochs", + default = 3.0, + type = float, + help = "Total number of training epochs to perform.") +parser.add_argument("--warmup_proportion", + default = 0.1, + type = float, + help = "Proportion of training to perform linear learning rate warmup for. " + "E.g., 0.1 = 10%% of training.") +parser.add_argument("--save_checkpoints_steps", + default = 1000, + type = int, + help = "How often to save the model checkpoint.") +parser.add_argument("--iterations_per_loop", + default = 1000, + type = int, + help = "How many steps to make in each estimator call.") + +### BEGIN - TO DELETE EVENTUALLY --> NO SENSE IN PYTORCH ### +parser.add_argument("--use_tpu", + default = False, + type = bool, + help = "Whether to use TPU or GPU/CPU.") +parser.add_argument("--tpu_name", + default = None, + type = str, + help = "The Cloud TPU to use for training. This should be either the name " + "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 " + "url.") +parser.add_argument("--tpu_zone", + default = None, + type = str, + help = "[Optional] GCE zone where the Cloud TPU is located in. If not " + "specified, we will attempt to automatically detect the GCE project from " + "metadata.") +parser.add_argument("--gcp_project", + default = None, + type = str, + help = "[Optional] Project name for the Cloud TPU-enabled project. If not " + "specified, we will attempt to automatically detect the GCE project from " + "metadata.") +parser.add_argument("--master", + default = None, + type = str, + help = "[Optional] TensorFlow master URL.") +parser.add_argument("--num_tpu_cores", + default = 8, + type = int, + help = "Only used if `use_tpu` is True. Total number of TPU cores to use.") +### END - TO DELETE EVENTUALLY --> NO SENSE IN PYTORCH ### + +args = parser.parse_args() \ No newline at end of file