# transformers/run_classifier_pytorch.py

# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT finetuning runner."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# import csv
# import os
# import modeling_pytorch
# import optimization
# import tokenization
# Command-line interface for the BERT fine-tuning runner.
import argparse


def _str2bool(value):
    """Parse a boolean-ish command-line string ("true"/"false", "1"/"0", ...).

    argparse's `type=bool` is a trap: bool("False") is True, so any non-empty
    value -- including the literal string "False" -- enabled a flag.  This
    converter keeps the original `--do_train True` call style working while
    actually honoring "False"/"0"/"no".

    Raises:
        argparse.ArgumentTypeError: if the value is not a recognized boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.lower()
    if lowered in ("yes", "true", "t", "1"):
        return True
    if lowered in ("no", "false", "f", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected, got %r" % value)


parser = argparse.ArgumentParser()

## Required parameters
parser.add_argument("--data_dir",
                    default=None,
                    type=str,
                    required=True,
                    help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
parser.add_argument("--bert_config_file",
                    default=None,
                    type=str,
                    required=True,
                    help="The config json file corresponding to the pre-trained BERT model. \n"
                         "This specifies the model architecture.")
parser.add_argument("--task_name",
                    default=None,
                    type=str,
                    required=True,
                    help="The name of the task to train.")
parser.add_argument("--vocab_file",
                    default=None,
                    type=str,
                    required=True,
                    help="The vocabulary file that the BERT model was trained on.")
parser.add_argument("--output_dir",
                    default=None,
                    type=str,
                    required=True,
                    help="The output directory where the model checkpoints will be written.")

## Other parameters
parser.add_argument("--init_checkpoint",
                    default=None,
                    type=str,
                    help="Initial checkpoint (usually from a pre-trained BERT model).")
parser.add_argument("--do_lower_case",
                    default=True,
                    type=_str2bool,  # was type=bool: "--do_lower_case False" parsed as True
                    help="Whether to lower case the input text. Should be True for uncased models and False for cased models.")
parser.add_argument("--max_seq_length",
                    default=128,
                    type=int,
                    help="The maximum total input sequence length after WordPiece tokenization. \n"
                         "Sequences longer than this will be truncated, and sequences shorter \n"
                         "than this will be padded.")
parser.add_argument("--do_train",
                    default=False,
                    type=_str2bool,  # was type=bool: "--do_train False" parsed as True
                    help="Whether to run training.")
parser.add_argument("--do_eval",
                    default=False,
                    type=_str2bool,  # was type=bool: "--do_eval False" parsed as True
                    help="Whether to run eval on the dev set.")
parser.add_argument("--train_batch_size",
                    default=32,
                    type=int,
                    help="Total batch size for training.")
parser.add_argument("--eval_batch_size",
                    default=8,
                    type=int,
                    help="Total batch size for eval.")
parser.add_argument("--learning_rate",
                    default=5e-5,
                    type=float,
                    help="The initial learning rate for Adam.")
parser.add_argument("--num_train_epochs",
                    default=3.0,
                    type=float,
                    help="Total number of training epochs to perform.")
parser.add_argument("--warmup_proportion",
                    default=0.1,
                    type=float,
                    help="Proportion of training to perform linear learning rate warmup for. "
                         "E.g., 0.1 = 10%% of training.")
parser.add_argument("--save_checkpoints_steps",
                    default=1000,
                    type=int,
                    help="How often to save the model checkpoint.")
parser.add_argument("--iterations_per_loop",
                    default=1000,
                    type=int,
                    help="How many steps to make in each estimator call.")

### BEGIN - TO DELETE EVENTUALLY --> NO SENSE IN PYTORCH ###
parser.add_argument("--use_tpu",
                    default=False,
                    type=_str2bool,  # was type=bool: "--use_tpu False" parsed as True
                    help="Whether to use TPU or GPU/CPU.")
parser.add_argument("--tpu_name",
                    default=None,
                    type=str,
                    help="The Cloud TPU to use for training. This should be either the name "
                         "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
                         "url.")
parser.add_argument("--tpu_zone",
                    default=None,
                    type=str,
                    help="[Optional] GCE zone where the Cloud TPU is located in. If not "
                         "specified, we will attempt to automatically detect the GCE project from "
                         "metadata.")
parser.add_argument("--gcp_project",
                    default=None,
                    type=str,
                    help="[Optional] Project name for the Cloud TPU-enabled project. If not "
                         "specified, we will attempt to automatically detect the GCE project from "
                         "metadata.")
parser.add_argument("--master",
                    default=None,
                    type=str,
                    help="[Optional] TensorFlow master URL.")
parser.add_argument("--num_tpu_cores",
                    default=8,
                    type=int,
                    help="Only used if `use_tpu` is True. Total number of TPU cores to use.")
### END - TO DELETE EVENTUALLY --> NO SENSE IN PYTORCH ###
# Parse sys.argv at import time; downstream code reads all settings from `args`.
args = parser.parse_args()