mirror of https://github.com/huggingface/transformers.git (synced 2025-08-02 11:11:05 +06:00)

read parameters from CLI, load model & tokenizer

parent d889e0b71b
commit b3261e7ace
@@ -30,12 +30,15 @@ Gao, Ming Zhou, and Hsiao-Wuen Hon. “Unified Language Model Pre-Training for
Natural Language Understanding and Generation.” (May 2019) ArXiv:1905.03197
"""

import argparse
import logging
import random

import numpy as np
import torch

from transformers import BertConfig, Bert2Rnd, BertTokenizer


logger = logging.getLogger(__name__)

@@ -43,25 +46,60 @@ def set_seed(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)


def load_and_cache_examples(args, tokenizer):
    raise NotImplementedError


def train(args, train_dataset, model, tokenizer):
    """ Fine-tune the pretrained model on the corpus. """
    # Data sampler
    # Data loader
    # Training
    raise NotImplementedError


def evaluate(args, model, tokenizer, prefix=""):
    raise NotImplementedError


def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument("--train_data_file",
                        default=None,
                        type=str,
                        required=True,
                        help="The input training data file (a text file).")
    parser.add_argument("--output_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")

    # Optional parameters
    parser.add_argument("--model_name_or_path",
                        default="bert-base-cased",
                        type=str,
                        help="The model checkpoint for weights initialization.")
    parser.add_argument("--seed", default=42, type=int)
    args = parser.parse_args()

    # Set up training device
    device = torch.device("cpu")
    args.n_gpu = 0  # set_seed() reads this attribute; no CLI flag defines it yet

    # Set seed
    set_seed(args)

    # Load pretrained model and tokenizer
    config_class, model_class, tokenizer_class = BertConfig, Bert2Rnd, BertTokenizer
    config = config_class.from_pretrained(args.model_name_or_path)
    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
    model = model_class.from_pretrained(args.model_name_or_path, config=config)
    model.to(device)

    logger.info("Training/evaluation parameters %s", args)

    # Training
    train_dataset = load_and_cache_examples(args, tokenizer)
    global_step, tr_loss = train(args, train_dataset, model, tokenizer)
    logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)


if __name__ == "__main__":
    main()
@@ -1,49 +0,0 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Finetuning seq2seq models for abstractive summarization.
|
||||
|
||||
The finetuning method for abstractive summarization is inspired by [1]. We
|
||||
concatenate the document and summary, mask words of the summary at random and
|
||||
maximizing the likelihood of masked words.
|
||||
|
||||
[1] Dong Li, Nan Yang, Wenhui Wang, Furu Wei, Xiaodong Liu, Yu Wang, Jianfeng
|
||||
Gao, Ming Zhou, and Hsiao-Wuen Hon. “Unified Language Model Pre-Training for
|
||||
Natural Language Understanding and Generation.” (May 2019) ArXiv:1905.03197
|
||||
"""

import logging
import random

import numpy as np
import torch

logger = logging.getLogger(__name__)


def set_seed(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)


def train(args, train_dataset, model, tokenizer):
    raise NotImplementedError


def evaluate(args, model, tokenizer, prefix=""):
    raise NotImplementedError
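The module docstring describes the finetuning objective: concatenate the document and summary, mask words of the summary at random, and maximize the likelihood of the masked words. A minimal sketch of that preprocessing step follows; the helper name, the 15% masking rate, and the `-1` ignore-label convention are assumptions made for illustration, not something defined by this commit.

```python
import random
from typing import List, Tuple


def mask_summary_tokens(document_ids: List[int],
                        summary_ids: List[int],
                        mask_token_id: int,
                        mask_prob: float = 0.15) -> Tuple[List[int], List[int]]:
    """Concatenate document and summary, then mask summary tokens at random.

    Returns (input_ids, labels) where labels is -1 everywhere except at masked
    summary positions, which hold the original token id. Both the masking rate
    and the -1 convention are placeholder assumptions.
    """
    input_ids = list(document_ids) + list(summary_ids)
    labels = [-1] * len(input_ids)
    offset = len(document_ids)  # summary tokens start right after the document
    for i, token_id in enumerate(summary_ids):
        if random.random() < mask_prob:
            labels[offset + i] = token_id          # predict the original token here
            input_ids[offset + i] = mask_token_id  # replace it with [MASK] in the input
    return input_ids, labels
```

Feeding `input_ids` to the model and computing a cross-entropy loss only where `labels` is not -1 then maximizes the likelihood of exactly those masked summary tokens, as the docstring describes.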