mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Add example for fine tuning BERT language model (#1)
Adds an example for loading a pre-trained BERT model and fine tune it as a language model (masked tokens & nextSentence) on your target corpus.
This commit is contained in:
parent
786cc41299
commit
a58361f197
674
examples/run_lm_finetuning.py
Normal file
674
examples/run_lm_finetuning.py
Normal file
@ -0,0 +1,674 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""BERT finetuning runner."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import logging
|
||||
import argparse
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, RandomSampler
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
from pytorch_pretrained_bert.tokenization import BertTokenizer
|
||||
from pytorch_pretrained_bert.modeling import BertForPreTraining
|
||||
from pytorch_pretrained_bert.optimization import BertAdam
|
||||
|
||||
from torch.utils.data import Dataset
|
||||
import random
|
||||
|
||||
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||
datefmt='%m/%d/%Y %H:%M:%S',
|
||||
level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BERTDataset(Dataset):
|
||||
def __init__(self, corpus_path, tokenizer, seq_len, encoding="utf-8", corpus_lines=None, on_memory=True):
|
||||
self.vocab = tokenizer.vocab
|
||||
self.tokenizer = tokenizer
|
||||
self.seq_len = seq_len
|
||||
self.on_memory = on_memory
|
||||
self.corpus_lines = corpus_lines # number of non-empty lines in input corpus
|
||||
self.corpus_path = corpus_path
|
||||
self.encoding = encoding
|
||||
self.current_doc = 0 # to avoid random sentence from same doc
|
||||
|
||||
# for loading samples directly from file
|
||||
self.sample_counter = 0 # used to keep track of full epochs on file
|
||||
self.line_buffer = None # keep second sentence of a pair in memory and use as first sentence in next pair
|
||||
|
||||
# for loading samples in memory
|
||||
self.current_random_doc = 0
|
||||
self.num_docs = 0
|
||||
|
||||
# load samples into memory
|
||||
if on_memory:
|
||||
self.all_docs = []
|
||||
doc = []
|
||||
self.corpus_lines = 0
|
||||
with open(corpus_path, "r", encoding=encoding) as f:
|
||||
for line in tqdm(f, desc="Loading Dataset", total=corpus_lines):
|
||||
line = line.strip()
|
||||
if line == "":
|
||||
self.all_docs.append(doc)
|
||||
doc = []
|
||||
else:
|
||||
doc.append(line)
|
||||
self.corpus_lines = self.corpus_lines + 1
|
||||
# if last row in file is not empty
|
||||
if self.all_docs[-1] != doc:
|
||||
self.all_docs.append(doc)
|
||||
|
||||
self.num_docs = len(self.all_docs)
|
||||
|
||||
# load samples later lazily from disk
|
||||
else:
|
||||
if self.corpus_lines is None:
|
||||
with open(corpus_path, "r", encoding=encoding) as f:
|
||||
self.corpus_lines = 0
|
||||
for line in tqdm(f, desc="Loading Dataset", total=corpus_lines):
|
||||
if line.strip() == "":
|
||||
self.num_docs += 1
|
||||
else:
|
||||
self.corpus_lines += 1
|
||||
|
||||
# if doc does not end with empty line
|
||||
if line.strip() != "":
|
||||
self.num_docs += 1
|
||||
|
||||
self.file = open(corpus_path, "r", encoding=encoding)
|
||||
self.random_file = open(corpus_path, "r", encoding=encoding)
|
||||
|
||||
def __len__(self):
|
||||
# last line of doc won't be used, because there's no "nextSentence". Additionally, we start counting at 0.
|
||||
return self.corpus_lines - self.num_docs - 1
|
||||
|
||||
def __getitem__(self, item):
|
||||
cur_id = self.sample_counter
|
||||
self.sample_counter += 1
|
||||
if not self.on_memory:
|
||||
# after one epoch we start again from beginning of file
|
||||
if cur_id != 0 and (cur_id % len(self) == 0):
|
||||
self.file.close()
|
||||
self.file = open(self.corpus_path, "r", encoding=self.encoding)
|
||||
|
||||
t1, t2, is_next_label = self.random_sent(item)
|
||||
|
||||
# tokenize
|
||||
tokens_a = self.tokenizer.tokenize(t1)
|
||||
tokens_b = self.tokenizer.tokenize(t2)
|
||||
|
||||
# combine to one sample
|
||||
cur_example = InputExample(guid=cur_id, tokens_a=tokens_a, tokens_b=tokens_b, is_next=is_next_label)
|
||||
|
||||
# transform sample to features
|
||||
cur_features = convert_example_to_features(cur_example, self.seq_len, self.tokenizer)
|
||||
|
||||
cur_tensors = {"input_ids": torch.tensor(cur_features.input_ids),
|
||||
"input_mask": torch.tensor(cur_features.input_mask),
|
||||
"segment_ids": torch.tensor(cur_features.segment_ids),
|
||||
"lm_label_ids": torch.tensor(cur_features.lm_label_ids),
|
||||
"is_next": torch.tensor(cur_features.is_next)}
|
||||
|
||||
return cur_tensors
|
||||
|
||||
def random_sent(self, index):
|
||||
"""
|
||||
Get one sample from corpus consisting of two sentences. With prob. 50% these are two subsequent sentences
|
||||
from one doc. With 50% the second sentence will be a random one from another doc.
|
||||
:param index: int, index of sample.
|
||||
:return: (str, str, int), sentence 1, sentence 2, isNextSentence Label
|
||||
"""
|
||||
t1, t2 = self.get_corpus_line(index)
|
||||
if random.random() > 0.5:
|
||||
label = 0
|
||||
else:
|
||||
t2 = self.get_random_line()
|
||||
label = 1
|
||||
|
||||
assert len(t1) > 0
|
||||
assert len(t2) > 0
|
||||
return t1, t2, label
|
||||
|
||||
def get_corpus_line(self, item):
|
||||
"""
|
||||
Get one sample from corpus consisting of a pair of two subsequent lines from the same doc.
|
||||
:param item: int, index of sample.
|
||||
:return: (str, str), two subsequent sentences from corpus
|
||||
"""
|
||||
t1 = ""
|
||||
t2 = ""
|
||||
assert item < self.corpus_lines
|
||||
if self.on_memory:
|
||||
# get the right doc
|
||||
doc_id = 0
|
||||
doc_start = 0
|
||||
doc_end = len(self.all_docs[doc_id]) - 2
|
||||
while item > doc_end:
|
||||
doc_id += 1
|
||||
doc_start = doc_end + 1
|
||||
doc_end += len(self.all_docs[doc_id]) - 1
|
||||
# get the right line within doc
|
||||
line_in_doc = item - doc_start
|
||||
t1 = self.all_docs[doc_id][line_in_doc]
|
||||
t2 = self.all_docs[doc_id][line_in_doc + 1]
|
||||
# used later to avoid random nextSentence from same doc
|
||||
self.current_doc = doc_id
|
||||
return t1, t2
|
||||
else:
|
||||
if self.line_buffer is None:
|
||||
# read first non-empty line of file
|
||||
while t1 == "" :
|
||||
t1 = self.file.__next__().strip()
|
||||
t2 = self.file.__next__().strip()
|
||||
else:
|
||||
# use t2 from previous iteration as new t1
|
||||
t1 = self.line_buffer
|
||||
t2 = self.file.__next__().strip()
|
||||
# skip empty rows that are used for separating documents and keep track of current doc id
|
||||
while t2 == "" or t1 == "":
|
||||
t1 = self.file.__next__().strip()
|
||||
t2 = self.file.__next__().strip()
|
||||
self.current_doc = self.current_doc+1
|
||||
self.line_buffer = t2
|
||||
|
||||
assert t1 != ""
|
||||
assert t2 != ""
|
||||
return t1, t2
|
||||
|
||||
def get_random_line(self):
|
||||
"""
|
||||
Get random line from another document for nextSentence task.
|
||||
:return: str, content of one line
|
||||
"""
|
||||
# Similar to original tf repo: This outer loop should rarely go for more than one iteration for large
|
||||
# corpora. However, just to be careful, we try to make sure that
|
||||
# the random document is not the same as the document we're processing.
|
||||
for _ in range(10):
|
||||
if self.on_memory:
|
||||
rand_doc_idx = random.randint(0, len(self.all_docs)-1)
|
||||
rand_doc = self.all_docs[rand_doc_idx]
|
||||
line = rand_doc[random.randrange(len(rand_doc))]
|
||||
else:
|
||||
rand_index = random.randint(1, self.corpus_lines if self.corpus_lines < 1000 else 1000)
|
||||
#pick random line
|
||||
for _ in range(rand_index):
|
||||
line = self.get_next_line()
|
||||
#check if our picked random line is really from another doc like we want it to be
|
||||
if self.current_random_doc != self.current_doc:
|
||||
break
|
||||
return line
|
||||
|
||||
def get_next_line(self):
|
||||
""" Gets next line of random_file and starts over when reaching end of file"""
|
||||
try:
|
||||
line = self.random_file.__next__().strip()
|
||||
#keep track of which document we are currently looking at to later avoid having the same doc as t1
|
||||
if line == "":
|
||||
self.current_random_doc = self.current_random_doc + 1
|
||||
line = self.random_file.__next__().strip()
|
||||
except StopIteration:
|
||||
self.random_file.close()
|
||||
self.random_file = open(self.corpus_path, "r", encoding=self.encoding)
|
||||
line = self.random_file.__next__().strip()
|
||||
return line
|
||||
|
||||
|
||||
class InputExample(object):
|
||||
"""A single training/test example for the language model."""
|
||||
|
||||
def __init__(self, guid, tokens_a, tokens_b=None, is_next=None, lm_labels=None):
|
||||
"""Constructs a InputExample.
|
||||
|
||||
Args:
|
||||
guid: Unique id for the example.
|
||||
tokens_a: string. The untokenized text of the first sequence. For single
|
||||
sequence tasks, only this sequence must be specified.
|
||||
tokens_b: (Optional) string. The untokenized text of the second sequence.
|
||||
Only must be specified for sequence pair tasks.
|
||||
label: (Optional) string. The label of the example. This should be
|
||||
specified for train and dev examples, but not for test examples.
|
||||
"""
|
||||
self.guid = guid
|
||||
self.tokens_a = tokens_a
|
||||
self.tokens_b = tokens_b
|
||||
self.is_next = is_next # nextSentence
|
||||
self.lm_labels = lm_labels # masked words for language model
|
||||
|
||||
|
||||
class InputFeatures(object):
|
||||
"""A single set of features of data."""
|
||||
|
||||
def __init__(self, input_ids, input_mask, segment_ids, is_next, lm_label_ids):
|
||||
self.input_ids = input_ids
|
||||
self.input_mask = input_mask
|
||||
self.segment_ids = segment_ids
|
||||
self.is_next = is_next
|
||||
self.lm_label_ids = lm_label_ids
|
||||
|
||||
|
||||
def random_word(tokens, tokenizer):
|
||||
"""
|
||||
Masking some random tokens for Language Model task with probabilities as in the original BERT paper.
|
||||
:param tokens: list of str, tokenized sentence.
|
||||
:param tokenizer: Tokenizer, object used for tokenization (we need it's vocab here)
|
||||
:return: (list of str, list of int), masked tokens and related labels for LM prediction
|
||||
"""
|
||||
output_label = []
|
||||
|
||||
for i, token in enumerate(tokens):
|
||||
prob = random.random()
|
||||
# mask token with 15% probability
|
||||
if prob < 0.15:
|
||||
prob /= 0.15
|
||||
|
||||
# 80% randomly change token to mask token
|
||||
if prob < 0.8:
|
||||
tokens[i] = "[MASK]"
|
||||
|
||||
# 10% randomly change token to random token
|
||||
elif prob < 0.9:
|
||||
tokens[i] = random.choice(list(tokenizer.vocab.items()))[0]
|
||||
|
||||
# -> rest 10% randomly keep current token
|
||||
|
||||
# append current token to output (we will predict these later)
|
||||
try:
|
||||
output_label.append(tokenizer.vocab[token])
|
||||
except KeyError:
|
||||
# For unknown words (should not occur with BPE vocab)
|
||||
output_label.append(tokenizer.vocab["[UNK]"])
|
||||
else:
|
||||
# no masking token (will be ignored by loss function later)
|
||||
output_label.append(-1)
|
||||
|
||||
return tokens, output_label
|
||||
|
||||
|
||||
def convert_example_to_features(example, max_seq_length, tokenizer):
|
||||
"""
|
||||
Convert a raw sample (pair of sentences as tokenized strings) into a proper training sample with
|
||||
IDs, LM labels, input_mask, CLS and SEP tokens etc.
|
||||
:param example: InputExample, containing sentence input as strings and is_next label
|
||||
:param max_seq_length: int, maximum length of sequence.
|
||||
:param tokenizer: Tokenizer
|
||||
:return: InputFeatures, containing all inputs and labels of one sample as IDs (as used for model training)
|
||||
"""
|
||||
tokens_a = example.tokens_a
|
||||
tokens_b = example.tokens_b
|
||||
# Modifies `tokens_a` and `tokens_b` in place so that the total
|
||||
# length is less than the specified length.
|
||||
# Account for [CLS], [SEP], [SEP] with "- 3"
|
||||
_truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
|
||||
|
||||
t1_random, t1_label = random_word(tokens_a, tokenizer)
|
||||
t2_random, t2_label = random_word(tokens_b, tokenizer)
|
||||
# concatenate lm labels and account for CLS, SEP, SEP
|
||||
lm_label_ids = ([-1] + t1_label + [-1] + t2_label + [-1])
|
||||
|
||||
# The convention in BERT is:
|
||||
# (a) For sequence pairs:
|
||||
# tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
|
||||
# type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
|
||||
# (b) For single sequences:
|
||||
# tokens: [CLS] the dog is hairy . [SEP]
|
||||
# type_ids: 0 0 0 0 0 0 0
|
||||
#
|
||||
# Where "type_ids" are used to indicate whether this is the first
|
||||
# sequence or the second sequence. The embedding vectors for `type=0` and
|
||||
# `type=1` were learned during pre-training and are added to the wordpiece
|
||||
# embedding vector (and position vector). This is not *strictly* necessary
|
||||
# since the [SEP] token unambigiously separates the sequences, but it makes
|
||||
# it easier for the model to learn the concept of sequences.
|
||||
#
|
||||
# For classification tasks, the first vector (corresponding to [CLS]) is
|
||||
# used as as the "sentence vector". Note that this only makes sense because
|
||||
# the entire model is fine-tuned.
|
||||
tokens = []
|
||||
segment_ids = []
|
||||
tokens.append("[CLS]")
|
||||
segment_ids.append(0)
|
||||
for token in tokens_a:
|
||||
tokens.append(token)
|
||||
segment_ids.append(0)
|
||||
tokens.append("[SEP]")
|
||||
segment_ids.append(0)
|
||||
|
||||
assert len(tokens_b) > 0
|
||||
for token in tokens_b:
|
||||
tokens.append(token)
|
||||
segment_ids.append(1)
|
||||
tokens.append("[SEP]")
|
||||
segment_ids.append(1)
|
||||
|
||||
input_ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||
|
||||
# The mask has 1 for real tokens and 0 for padding tokens. Only real
|
||||
# tokens are attended to.
|
||||
input_mask = [1] * len(input_ids)
|
||||
|
||||
# Zero-pad up to the sequence length.
|
||||
while len(input_ids) < max_seq_length:
|
||||
input_ids.append(0)
|
||||
input_mask.append(0)
|
||||
segment_ids.append(0)
|
||||
lm_label_ids.append(-1)
|
||||
|
||||
assert len(input_ids) == max_seq_length
|
||||
assert len(input_mask) == max_seq_length
|
||||
assert len(segment_ids) == max_seq_length
|
||||
assert len(lm_label_ids) == max_seq_length
|
||||
|
||||
if example.guid < 5:
|
||||
logger.info("*** Example ***")
|
||||
logger.info("guid: %s" % (example.guid))
|
||||
logger.info("tokens: %s" % " ".join(
|
||||
[str(x) for x in tokens]))
|
||||
logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
|
||||
logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
|
||||
logger.info(
|
||||
"segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
|
||||
logger.info("LM label: %s " % (lm_label_ids))
|
||||
logger.info("Is next sentence label: %s " % (example.is_next))
|
||||
|
||||
features = InputFeatures(input_ids=input_ids,
|
||||
input_mask=input_mask,
|
||||
segment_ids=segment_ids,
|
||||
lm_label_ids=lm_label_ids,
|
||||
is_next=example.is_next)
|
||||
return features
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
## Required parameters
|
||||
parser.add_argument("--train_file",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The input train corpus.")
|
||||
parser.add_argument("--bert_model", default=None, type=str, required=True,
|
||||
help="Bert pre-trained model selected in the list: bert-base-uncased, "
|
||||
"bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
|
||||
parser.add_argument("--output_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The output directory where the model checkpoints will be written.")
|
||||
|
||||
## Other parameters
|
||||
parser.add_argument("--max_seq_length",
|
||||
default=128,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after WordPiece tokenization. \n"
|
||||
"Sequences longer than this will be truncated, and sequences shorter \n"
|
||||
"than this will be padded.")
|
||||
parser.add_argument("--do_train",
|
||||
default=False,
|
||||
action='store_true',
|
||||
help="Whether to run training.")
|
||||
parser.add_argument("--train_batch_size",
|
||||
default=32,
|
||||
type=int,
|
||||
help="Total batch size for training.")
|
||||
parser.add_argument("--eval_batch_size",
|
||||
default=8,
|
||||
type=int,
|
||||
help="Total batch size for eval.")
|
||||
parser.add_argument("--learning_rate",
|
||||
default=3e-5,
|
||||
type=float,
|
||||
help="The initial learning rate for Adam.")
|
||||
parser.add_argument("--num_train_epochs",
|
||||
default=3.0,
|
||||
type=float,
|
||||
help="Total number of training epochs to perform.")
|
||||
parser.add_argument("--warmup_proportion",
|
||||
default=0.1,
|
||||
type=float,
|
||||
help="Proportion of training to perform linear learning rate warmup for. "
|
||||
"E.g., 0.1 = 10%% of training.")
|
||||
parser.add_argument("--no_cuda",
|
||||
default=False,
|
||||
action='store_true',
|
||||
help="Whether not to use CUDA when available")
|
||||
parser.add_argument("--on_memory",
|
||||
default=False,
|
||||
action='store_true',
|
||||
help="Whether to load train samples into memory or use disk")
|
||||
parser.add_argument("--local_rank",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="local_rank for distributed training on gpus")
|
||||
parser.add_argument('--seed',
|
||||
type=int,
|
||||
default=42,
|
||||
help="random seed for initialization")
|
||||
parser.add_argument('--gradient_accumulation_steps',
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of updates steps to accumualte before performing a backward/update pass.")
|
||||
parser.add_argument('--optimize_on_cpu',
|
||||
default=False,
|
||||
action='store_true',
|
||||
help="Whether to perform optimization and keep the optimizer averages on CPU")
|
||||
parser.add_argument('--fp16',
|
||||
default=False,
|
||||
action='store_true',
|
||||
help="Whether to use 16-bit float precision instead of 32-bit")
|
||||
parser.add_argument('--loss_scale',
|
||||
type=float, default=128,
|
||||
help='Loss scaling, positive power of 2 values can improve fp16 convergence.')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.local_rank == -1 or args.no_cuda:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
||||
n_gpu = torch.cuda.device_count()
|
||||
else:
|
||||
device = torch.device("cuda", args.local_rank)
|
||||
n_gpu = 1
|
||||
# Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
||||
torch.distributed.init_process_group(backend='nccl')
|
||||
if args.fp16:
|
||||
logger.info("16-bits training currently not supported in distributed training")
|
||||
args.fp16 = False # (see https://github.com/pytorch/pytorch/pull/13496)
|
||||
logger.info("device %s n_gpu %d distributed training %r", device, n_gpu, bool(args.local_rank != -1))
|
||||
|
||||
if args.gradient_accumulation_steps < 1:
|
||||
raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
|
||||
args.gradient_accumulation_steps))
|
||||
|
||||
args.train_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)
|
||||
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
if n_gpu > 0:
|
||||
torch.cuda.manual_seed_all(args.seed)
|
||||
|
||||
if not args.do_train and not args.do_eval:
|
||||
raise ValueError("At least one of `do_train` or `do_eval` must be True.")
|
||||
|
||||
if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
|
||||
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
tokenizer = BertTokenizer.from_pretrained(args.bert_model)
|
||||
|
||||
#train_examples = None
|
||||
num_train_steps = None
|
||||
if args.do_train:
|
||||
print("Loading Train Dataset", args.train_file)
|
||||
train_dataset = BERTDataset(args.train_file, tokenizer, seq_len=args.max_seq_length,
|
||||
corpus_lines=None, on_memory=args.on_memory)
|
||||
num_train_steps = int(
|
||||
len(train_dataset) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)
|
||||
|
||||
# Prepare model
|
||||
model = BertForPreTraining.from_pretrained(args.bert_model)
|
||||
if args.fp16:
|
||||
model.half()
|
||||
model.to(device)
|
||||
if args.local_rank != -1:
|
||||
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
|
||||
output_device=args.local_rank)
|
||||
elif n_gpu > 1:
|
||||
model = torch.nn.DataParallel(model)
|
||||
|
||||
# Prepare optimizer
|
||||
if args.fp16:
|
||||
param_optimizer = [(n, param.clone().detach().to('cpu').float().requires_grad_()) \
|
||||
for n, param in model.named_parameters()]
|
||||
elif args.optimize_on_cpu:
|
||||
param_optimizer = [(n, param.clone().detach().to('cpu').requires_grad_()) \
|
||||
for n, param in model.named_parameters()]
|
||||
else:
|
||||
param_optimizer = list(model.named_parameters())
|
||||
no_decay = ['bias', 'gamma', 'beta']
|
||||
optimizer_grouped_parameters = [
|
||||
{'params': [p for n, p in param_optimizer if n not in no_decay], 'weight_decay_rate': 0.01},
|
||||
{'params': [p for n, p in param_optimizer if n in no_decay], 'weight_decay_rate': 0.0}
|
||||
]
|
||||
optimizer = BertAdam(optimizer_grouped_parameters,
|
||||
lr=args.learning_rate,
|
||||
warmup=args.warmup_proportion,
|
||||
t_total=num_train_steps)
|
||||
|
||||
global_step = 0
|
||||
if args.do_train:
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(" Num examples = %d", len(train_dataset))
|
||||
logger.info(" Batch size = %d", args.train_batch_size)
|
||||
logger.info(" Num steps = %d", num_train_steps)
|
||||
|
||||
if args.local_rank == -1:
|
||||
train_sampler = RandomSampler(train_dataset)
|
||||
else:
|
||||
#TODO: check if this works with current data generator from disk that relies on file.__next__
|
||||
# (it doesn't return item back by index)
|
||||
train_sampler = DistributedSampler(train_dataset)
|
||||
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
|
||||
|
||||
model.train()
|
||||
for _ in trange(int(args.num_train_epochs), desc="Epoch"):
|
||||
tr_loss = 0
|
||||
nb_tr_examples, nb_tr_steps = 0, 0
|
||||
for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
|
||||
batch = tuple(t.to(device) for t in batch.values())
|
||||
input_ids, input_mask, segment_ids, lm_label_ids, is_next = batch
|
||||
loss = model(input_ids, segment_ids, input_mask, lm_label_ids, is_next)
|
||||
if n_gpu > 1:
|
||||
loss = loss.mean() # mean() to average on multi-gpu.
|
||||
if args.fp16 and args.loss_scale != 1.0:
|
||||
# rescale loss for fp16 training
|
||||
# see https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html
|
||||
loss = loss * args.loss_scale
|
||||
if args.gradient_accumulation_steps > 1:
|
||||
loss = loss / args.gradient_accumulation_steps
|
||||
loss.backward()
|
||||
tr_loss += loss.item()
|
||||
nb_tr_examples += input_ids.size(0)
|
||||
nb_tr_steps += 1
|
||||
if (step + 1) % args.gradient_accumulation_steps == 0:
|
||||
if args.fp16 or args.optimize_on_cpu:
|
||||
if args.fp16 and args.loss_scale != 1.0:
|
||||
# scale down gradients for fp16 training
|
||||
for param in model.parameters():
|
||||
param.grad.data = param.grad.data / args.loss_scale
|
||||
is_nan = set_optimizer_params_grad(param_optimizer, model.named_parameters(), test_nan=True)
|
||||
if is_nan:
|
||||
logger.info("FP16 TRAINING: Nan in gradients, reducing loss scaling")
|
||||
args.loss_scale = args.loss_scale / 2
|
||||
model.zero_grad()
|
||||
continue
|
||||
optimizer.step()
|
||||
copy_optimizer_params_to_model(model.named_parameters(), param_optimizer)
|
||||
else:
|
||||
optimizer.step()
|
||||
model.zero_grad()
|
||||
global_step += 1
|
||||
|
||||
logger.info("** ** * Saving fine - tuned model ** ** * ")
|
||||
output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
|
||||
if n_gpu > 1:
|
||||
torch.save(model.module.bert.state_dict(), output_model_file)
|
||||
else:
|
||||
torch.save(model.bert.state_dict(), output_model_file)
|
||||
|
||||
|
||||
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
|
||||
"""Truncates a sequence pair in place to the maximum length."""
|
||||
|
||||
# This is a simple heuristic which will always truncate the longer sequence
|
||||
# one token at a time. This makes more sense than truncating an equal percent
|
||||
# of tokens from each, since if one sequence is very short then each token
|
||||
# that's truncated likely contains more information than a longer sequence.
|
||||
while True:
|
||||
total_length = len(tokens_a) + len(tokens_b)
|
||||
if total_length <= max_length:
|
||||
break
|
||||
if len(tokens_a) > len(tokens_b):
|
||||
tokens_a.pop()
|
||||
else:
|
||||
tokens_b.pop()
|
||||
|
||||
|
||||
def accuracy(out, labels):
|
||||
outputs = np.argmax(out, axis=1)
|
||||
return np.sum(outputs == labels)
|
||||
|
||||
|
||||
def copy_optimizer_params_to_model(named_params_model, named_params_optimizer):
|
||||
""" Utility function for optimize_on_cpu and 16-bits training.
|
||||
Copy the parameters optimized on CPU/RAM back to the model on GPU
|
||||
"""
|
||||
for (name_opti, param_opti), (name_model, param_model) in zip(named_params_optimizer, named_params_model):
|
||||
if name_opti != name_model:
|
||||
logger.error("name_opti != name_model: {} {}".format(name_opti, name_model))
|
||||
raise ValueError
|
||||
param_model.data.copy_(param_opti.data)
|
||||
|
||||
|
||||
def set_optimizer_params_grad(named_params_optimizer, named_params_model, test_nan=False):
|
||||
""" Utility function for optimize_on_cpu and 16-bits training.
|
||||
Copy the gradient of the GPU parameters to the CPU/RAMM copy of the model
|
||||
"""
|
||||
is_nan = False
|
||||
for (name_opti, param_opti), (name_model, param_model) in zip(named_params_optimizer, named_params_model):
|
||||
if name_opti != name_model:
|
||||
logger.error("name_opti != name_model: {} {}".format(name_opti, name_model))
|
||||
raise ValueError
|
||||
if param_model.grad is not None:
|
||||
if test_nan and torch.isnan(param_model.grad).sum() > 0:
|
||||
is_nan = True
|
||||
if param_opti.grad is None:
|
||||
param_opti.grad = torch.nn.Parameter(param_opti.data.new().resize_(*param_opti.data.size()))
|
||||
param_opti.grad.data.copy_(param_model.grad.data)
|
||||
else:
|
||||
param_opti.grad = None
|
||||
return is_nan
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Loading…
Reference in New Issue
Block a user