# Mirror of https://github.com/huggingface/transformers.git
# Synced 2025-07-04 21:30:07 +06:00
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT finetuning runner."""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
# import csv
|
|
# import os
|
|
# import modeling_pytorch
|
|
# import optimization
|
|
# import tokenization
|
|
|
|
import argparse
|
|
|
|
parser = argparse.ArgumentParser()
|
|
|
|
## Required parameters
|
|
parser.add_argument("--data_dir",
|
|
default = None,
|
|
type = str,
|
|
required = True,
|
|
help = "The input data dir. Should contain the .tsv files (or other data files) for the task.")
|
|
parser.add_argument("--bert_config_file",
|
|
default = None,
|
|
type = str,
|
|
required = True,
|
|
help = "The config json file corresponding to the pre-trained BERT model. \n"
|
|
"This specifies the model architecture.")
|
|
parser.add_argument("--task_name",
|
|
default = None,
|
|
type = str,
|
|
required = True,
|
|
help = "The name of the task to train.")
|
|
parser.add_argument("--vocab_file",
|
|
default = None,
|
|
type = str,
|
|
required = True,
|
|
help = "The vocabulary file that the BERT model was trained on.")
|
|
parser.add_argument("--output_dir",
|
|
default = None,
|
|
type = str,
|
|
required = True,
|
|
help = "The output directory where the model checkpoints will be written.")
|
|
|
|
## Other parameters
|
|
parser.add_argument("--init_checkpoint",
|
|
default = None,
|
|
type = str,
|
|
help = "Initial checkpoint (usually from a pre-trained BERT model).")
|
|
parser.add_argument("--do_lower_case",
|
|
default = True,
|
|
type = bool,
|
|
help = "Whether to lower case the input text. Should be True for uncased models and False for cased models.")
|
|
parser.add_argument("--max_seq_length",
|
|
default = 128,
|
|
type = int,
|
|
help = "The maximum total input sequence length after WordPiece tokenization. \n"
|
|
"Sequences longer than this will be truncated, and sequences shorter \n"
|
|
"than this will be padded.")
|
|
parser.add_argument("--do_train",
|
|
default = False,
|
|
type = bool,
|
|
help = "Whether to run training.")
|
|
parser.add_argument("--do_eval",
|
|
default = False,
|
|
type = bool,
|
|
help = "Whether to run eval on the dev set.")
|
|
parser.add_argument("--train_batch_size",
|
|
default = 32,
|
|
type = int,
|
|
help = "Total batch size for training.")
|
|
parser.add_argument("--eval_batch_size",
|
|
default = 8,
|
|
type = int,
|
|
help = "Total batch size for eval.")
|
|
parser.add_argument("--learning_rate",
|
|
default = 5e-5,
|
|
type = float,
|
|
help = "The initial learning rate for Adam.")
|
|
parser.add_argument("--num_train_epochs",
|
|
default = 3.0,
|
|
type = float,
|
|
help = "Total number of training epochs to perform.")
|
|
parser.add_argument("--warmup_proportion",
|
|
default = 0.1,
|
|
type = float,
|
|
help = "Proportion of training to perform linear learning rate warmup for. "
|
|
"E.g., 0.1 = 10%% of training.")
|
|
parser.add_argument("--save_checkpoints_steps",
|
|
default = 1000,
|
|
type = int,
|
|
help = "How often to save the model checkpoint.")
|
|
parser.add_argument("--iterations_per_loop",
|
|
default = 1000,
|
|
type = int,
|
|
help = "How many steps to make in each estimator call.")
|
|
|
|
### BEGIN - TO DELETE EVENTUALLY --> NO SENSE IN PYTORCH ###
|
|
parser.add_argument("--use_tpu",
|
|
default = False,
|
|
type = bool,
|
|
help = "Whether to use TPU or GPU/CPU.")
|
|
parser.add_argument("--tpu_name",
|
|
default = None,
|
|
type = str,
|
|
help = "The Cloud TPU to use for training. This should be either the name "
|
|
"used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
|
|
"url.")
|
|
parser.add_argument("--tpu_zone",
|
|
default = None,
|
|
type = str,
|
|
help = "[Optional] GCE zone where the Cloud TPU is located in. If not "
|
|
"specified, we will attempt to automatically detect the GCE project from "
|
|
"metadata.")
|
|
parser.add_argument("--gcp_project",
|
|
default = None,
|
|
type = str,
|
|
help = "[Optional] Project name for the Cloud TPU-enabled project. If not "
|
|
"specified, we will attempt to automatically detect the GCE project from "
|
|
"metadata.")
|
|
parser.add_argument("--master",
|
|
default = None,
|
|
type = str,
|
|
help = "[Optional] TensorFlow master URL.")
|
|
parser.add_argument("--num_tpu_cores",
|
|
default = 8,
|
|
type = int,
|
|
help = "Only used if `use_tpu` is True. Total number of TPU cores to use.")
|
|
### END - TO DELETE EVENTUALLY --> NO SENSE IN PYTORCH ###
|
|
|
|
args = parser.parse_args() |