# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT finetuning runner."""

from __future__ import absolute_import, division, print_function, unicode_literals

import argparse
import logging
import os
import random
from io import open

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, RandomSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange

from pytorch_pretrained_bert.modeling import BertForPreTraining
from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
logger = logging.getLogger(__name__)


class BERTDataset(Dataset):
    """Dataset of sentence pairs for masked-LM / next-sentence fine-tuning.

    The corpus file holds one sentence per line; blank lines separate
    documents. A sample is a pair of lines: either two subsequent lines of one
    document or a line followed by a randomly drawn line (see `random_sent`).

    With `on_memory=True` the whole corpus is loaded into `all_docs` and
    samples are addressed by index. Otherwise lines are streamed from disk in
    file order and the index passed to `__getitem__` is not used to locate the
    sample.
    """

    def __init__(self, corpus_path, tokenizer, seq_len, encoding="utf-8", corpus_lines=None, on_memory=True):
        """Load the corpus into memory or count its lines and open it for streaming.

        :param corpus_path: str, path of the corpus file.
        :param tokenizer: tokenizer exposing `vocab` and `tokenize`.
        :param seq_len: int, length every sample is padded/truncated to.
        :param encoding: str, text encoding of the corpus file.
        :param corpus_lines: int or None, number of non-empty lines. It is
            recomputed in memory mode; in disk mode it is counted only when None.
        :param on_memory: bool, load all documents into memory if True.
        :raises ValueError: in memory mode, if the corpus yields no sentence pair.
        """
        self.vocab = tokenizer.vocab
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.on_memory = on_memory
        self.corpus_lines = corpus_lines  # number of non-empty lines in input corpus
        self.corpus_path = corpus_path
        self.encoding = encoding
        self.current_doc = 0  # to avoid random sentence from same doc

        # for loading samples directly from file
        self.sample_counter = 0  # used to keep track of full epochs on file
        self.line_buffer = None  # keep second sentence of a pair in memory and use as first sentence in next pair

        # for loading samples in memory
        self.current_random_doc = 0
        self.num_docs = 0
        self.sample_to_doc = []  # map sample index to doc and line

        if on_memory:
            self._load_into_memory(corpus_lines)
        else:
            if self.corpus_lines is None:
                self._count_lines_and_docs(corpus_lines)
            self.file = open(corpus_path, "r", encoding=encoding)
            self.random_file = open(corpus_path, "r", encoding=encoding)

    def _load_into_memory(self, progress_total):
        """Read every document into `all_docs` and index the sentence pairs."""
        self.all_docs = []
        self.corpus_lines = 0
        doc = []
        with open(self.corpus_path, "r", encoding=self.encoding) as f:
            for line in tqdm(f, desc="Loading Dataset", total=progress_total):
                line = line.strip()
                if line == "":
                    # Only a non-empty doc is closed: repeated or trailing blank
                    # lines must neither create empty docs nor drop valid samples.
                    if doc:
                        self._close_doc(doc)
                        doc = []
                else:
                    # store as one sample
                    self.sample_to_doc.append({"doc_id": len(self.all_docs), "line": len(doc)})
                    doc.append(line)
                    self.corpus_lines += 1
        # the last doc is still open if the file does not end with a blank line
        if doc:
            self._close_doc(doc)
        self.num_docs = len(self.all_docs)
        if not self.sample_to_doc:
            raise ValueError(
                "Corpus ({}) contains no pair of subsequent sentences.".format(self.corpus_path))

    def _close_doc(self, doc):
        """Store a finished doc and drop the sample that starts at its last line."""
        self.all_docs.append(doc)
        # the last line of a doc has no subsequent line, so it cannot start a pair
        self.sample_to_doc.pop()

    def _count_lines_and_docs(self, progress_total):
        """Count non-empty lines and documents of the corpus file on disk."""
        self.corpus_lines = 0
        inside_doc = False
        with open(self.corpus_path, "r", encoding=self.encoding) as f:
            for line in tqdm(f, desc="Loading Dataset", total=progress_total):
                if line.strip() == "":
                    # count a doc once, when the first blank line after it is seen
                    if inside_doc:
                        self.num_docs += 1
                        inside_doc = False
                else:
                    self.corpus_lines += 1
                    inside_doc = True
        # if the file does not end with an empty line the last doc is still open
        if inside_doc:
            self.num_docs += 1

    def __len__(self):
        # last line of doc won't be used, because there's no "nextSentence". Additionally, we start counting at 0.
        return self.corpus_lines - self.num_docs - 1

    def __getitem__(self, item):
        """Build one training sample as a tuple of tensors.

        :param item: int, sample index (used for lookup in memory mode only).
        :return: tuple of tensors (input_ids, input_mask, segment_ids,
            lm_label_ids, is_next).
        """
        cur_id = self.sample_counter
        self.sample_counter += 1
        if not self.on_memory:
            # after one epoch we start again from beginning of file
            if cur_id != 0 and (cur_id % len(self) == 0):
                self.file.close()
                self.file = open(self.corpus_path, "r", encoding=self.encoding)
                # Forget the streaming state as well; otherwise the first pair of
                # the new epoch would join the previously buffered line with the
                # first line of the file.
                self.line_buffer = None
                self.current_doc = 0

        t1, t2, is_next_label = self.random_sent(item)

        # tokenize
        tokens_a = self.tokenizer.tokenize(t1)
        tokens_b = self.tokenizer.tokenize(t2)

        # combine to one sample
        cur_example = InputExample(guid=cur_id, tokens_a=tokens_a, tokens_b=tokens_b, is_next=is_next_label)

        # transform sample to features
        cur_features = convert_example_to_features(cur_example, self.seq_len, self.tokenizer)

        cur_tensors = (torch.tensor(cur_features.input_ids),
                       torch.tensor(cur_features.input_mask),
                       torch.tensor(cur_features.segment_ids),
                       torch.tensor(cur_features.lm_label_ids),
                       torch.tensor(cur_features.is_next))

        return cur_tensors

    def random_sent(self, index):
        """
        Get one sample from corpus consisting of two sentences. With prob. 50% these are two subsequent sentences
        from one doc. With 50% the second sentence will be a random one from another doc.
        :param index: int, index of sample.
        :return: (str, str, int), sentence 1, sentence 2, isNextSentence Label
            (0 = subsequent sentences, 1 = random second sentence)
        """
        t1, t2 = self.get_corpus_line(index)
        if random.random() > 0.5:
            label = 0
        else:
            t2 = self.get_random_line()
            label = 1

        assert len(t1) > 0
        assert len(t2) > 0
        return t1, t2, label

    def get_corpus_line(self, item):
        """
        Get one sample from corpus consisting of a pair of two subsequent lines from the same doc.
        :param item: int, index of sample.
        :return: (str, str), two subsequent sentences from corpus
        """
        t1 = ""
        t2 = ""
        assert item < self.corpus_lines
        if self.on_memory:
            sample = self.sample_to_doc[item]
            doc = self.all_docs[sample["doc_id"]]
            t1 = doc[sample["line"]]
            t2 = doc[sample["line"] + 1]
            # used later to avoid random nextSentence from same doc
            self.current_doc = sample["doc_id"]
            return t1, t2

        if self.line_buffer is None:
            # read first non-empty line of file, then the line that follows it
            while t1 == "":
                t1 = next(self.file).strip()
            t2 = next(self.file).strip()
        else:
            # use t2 from previous iteration as new t1
            t1 = self.line_buffer
            t2 = next(self.file).strip()
            # skip empty rows that are used for separating documents and keep track of current doc id
            while t2 == "" or t1 == "":
                t1 = next(self.file).strip()
                t2 = next(self.file).strip()
                self.current_doc = self.current_doc + 1
        self.line_buffer = t2

        assert t1 != ""
        assert t2 != ""
        return t1, t2

    def get_random_line(self):
        """
        Get random line from another document for nextSentence task.
        :return: str, content of one line
        """
        # Similar to original tf repo: This outer loop should rarely go for more than one iteration for large
        # corpora. However, just to be careful, we try to make sure that
        # the random document is not the same as the document we're processing.
        for _ in range(10):
            if self.on_memory:
                rand_doc_idx = random.randint(0, len(self.all_docs) - 1)
                rand_doc = self.all_docs[rand_doc_idx]
                line = rand_doc[random.randrange(len(rand_doc))]
                # remember which doc was drawn so that the check below is meaningful
                self.current_random_doc = rand_doc_idx
            else:
                rand_index = random.randint(1, min(self.corpus_lines, 1000))
                # pick random line by advancing the second file handle
                for _ in range(rand_index):
                    line = self.get_next_line()
            # check if our picked random line is really from another doc like we want it to be
            if self.current_random_doc != self.current_doc:
                break
        return line

    def get_next_line(self):
        """Return the next non-empty line of random_file, starting over at end of file.

        `current_random_doc` is advanced when a separator (blank lines) is
        crossed and reset to 0 when the file is reopened.

        :return: str, stripped non-empty line
        :raises ValueError: if the corpus file contains no non-empty line.
        """
        reopened = False
        gap_counted = False
        while True:
            try:
                line = next(self.random_file).strip()
            except StopIteration:
                if reopened:
                    # a full pass over the file found nothing to return
                    raise ValueError(
                        "Corpus file ({}) contains no non-empty lines.".format(self.corpus_path))
                reopened = True
                self.random_file.close()
                self.random_file = open(self.corpus_path, "r", encoding=self.encoding)
                self.current_random_doc = 0
                # blank lines at the very start of the file do not separate docs
                gap_counted = True
                continue
            if line != "":
                return line
            # keep track of which document we are currently looking at to later avoid having the same doc as t1;
            # several blank lines in a row still count as a single separator
            if not gap_counted:
                self.current_random_doc = self.current_random_doc + 1
                gap_counted = True


class InputExample(object):
    """A single training/test example for the language model."""

    def __init__(self, guid, tokens_a, tokens_b=None, is_next=None, lm_labels=None):
        """Constructs a InputExample.

        Args:
            guid: Unique id for the example.
            tokens_a: The first sequence. In this file it is the list of
                tokens produced by the tokenizer.
            tokens_b: (Optional) The second sequence, in the same form as
                tokens_a. Only must be specified for sequence pair tasks.
            is_next: (Optional) next-sentence label of the pair.
            lm_labels: (Optional) masked words for the language model.
        """
        self.guid = guid
        self.tokens_a = tokens_a
        self.tokens_b = tokens_b
        self.is_next = is_next  # nextSentence
        self.lm_labels = lm_labels  # masked words for language model


class InputFeatures(object):
    """A single set of features of data."""

    def __init__(self, input_ids, input_mask, segment_ids, is_next, lm_label_ids):
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.is_next = is_next
        self.lm_label_ids = lm_label_ids


def random_word(tokens, tokenizer):
    """
    Masking some random tokens for Language Model task with probabilities as in the original BERT paper.
    The list `tokens` is modified in place.
    :param tokens: list of str, tokenized sentence.
    :param tokenizer: Tokenizer, object used for tokenization (we need its vocab here)
    :return: (list of str, list of int), masked tokens and related labels for LM prediction
    """
    output_label = []
    vocab_tokens = None  # built lazily, at most once per call

    for i, token in enumerate(tokens):
        prob = random.random()
        if prob >= 0.15:
            # no masking token (will be ignored by loss function later)
            output_label.append(-1)
            continue

        # mask token with 15% probability; rescale prob to [0, 1) for the split below
        prob /= 0.15

        if prob < 0.8:
            # 80% randomly change token to mask token
            tokens[i] = "[MASK]"
        elif prob < 0.9:
            # 10% randomly change token to random token
            if vocab_tokens is None:
                vocab_tokens = list(tokenizer.vocab)
            tokens[i] = random.choice(vocab_tokens)
        # -> rest 10% randomly keep current token

        # append current token to output (we will predict these later)
        try:
            output_label.append(tokenizer.vocab[token])
        except KeyError:
            # For unknown words (should not occur with BPE vocab)
            output_label.append(tokenizer.vocab["[UNK]"])
            logger.warning("Cannot find token '{}' in vocab. Using [UNK] insetad".format(token))

    return tokens, output_label


def convert_example_to_features(example, max_seq_length, tokenizer):
    """
    Convert a raw sample (pair of sentences as tokenized strings) into a proper training sample with
    IDs, LM labels, input_mask, CLS and SEP tokens etc.
    :param example: InputExample, containing sentence input as strings and is_next label
    :param max_seq_length: int, maximum length of sequence.
    :param tokenizer: Tokenizer
    :return: InputFeatures, containing all inputs and labels of one sample as IDs (as used for model training)
    """
    tokens_a = example.tokens_a
    tokens_b = example.tokens_b
    # Modifies `tokens_a` and `tokens_b` in place so that the total
    # length is less than the specified length.
    # Account for [CLS], [SEP], [SEP] with "- 3"
    _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)

    tokens_a, t1_label = random_word(tokens_a, tokenizer)
    tokens_b, t2_label = random_word(tokens_b, tokenizer)
    # concatenate lm labels and account for CLS, SEP, SEP
    lm_label_ids = [-1] + t1_label + [-1] + t2_label + [-1]

    # The convention in BERT is:
    # (a) For sequence pairs:
    #  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
    #  type_ids: 0     0  0    0    0     0       0 0     1  1  1  1   1 1
    # (b) For single sequences:
    #  tokens:   [CLS] the dog is hairy . [SEP]
    #  type_ids: 0     0   0   0  0     0 0
    #
    # Where "type_ids" are used to indicate whether this is the first
    # sequence or the second sequence. The embedding vectors for `type=0` and
    # `type=1` were learned during pre-training and are added to the wordpiece
    # embedding vector (and position vector). This is not *strictly* necessary
    # since the [SEP] token unambiguously separates the sequences, but it makes
    # it easier for the model to learn the concept of sequences.
    #
    # For classification tasks, the first vector (corresponding to [CLS]) is
    # used as the "sentence vector". Note that this only makes sense because
    # the entire model is fine-tuned.
    assert len(tokens_b) > 0
    tokens = ["[CLS]"] + list(tokens_a) + ["[SEP]"] + list(tokens_b) + ["[SEP]"]
    segment_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)

    input_ids = tokenizer.convert_tokens_to_ids(tokens)

    # The mask has 1 for real tokens and 0 for padding tokens. Only real
    # tokens are attended to.
    input_mask = [1] * len(input_ids)

    # Zero-pad up to the sequence length.
    while len(input_ids) < max_seq_length:
        input_ids.append(0)
        input_mask.append(0)
        segment_ids.append(0)
        lm_label_ids.append(-1)

    assert len(input_ids) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(segment_ids) == max_seq_length
    assert len(lm_label_ids) == max_seq_length

    # log the first few examples so the preprocessing can be inspected
    if example.guid < 5:
        logger.info("*** Example ***")
        logger.info("guid: %s" % (example.guid))
        logger.info("tokens: %s" % " ".join(
                [str(x) for x in tokens]))
        logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
        logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
        logger.info(
                "segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
        logger.info("LM label: %s " % (lm_label_ids))
        logger.info("Is next sentence label: %s " % (example.is_next))

    features = InputFeatures(input_ids=input_ids,
                             input_mask=input_mask,
                             segment_ids=segment_ids,
                             lm_label_ids=lm_label_ids,
                             is_next=example.is_next)
    return features


def _build_arg_parser():
    """Create the command line parser of the fine-tuning script."""
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--train_file",
                        default=None,
                        type=str,
                        required=True,
                        help="The input train corpus.")
    parser.add_argument("--bert_model", default=None, type=str, required=True,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                             "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
    parser.add_argument("--output_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The output directory where the model checkpoints will be written.")

    ## Other parameters
    # NOTE: this option used to be registered a second time further down, which
    # made argparse fail with "conflicting option string" before parsing.
    parser.add_argument("--do_lower_case",
                        action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--max_seq_length",
                        default=128,
                        type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=3e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion",
                        default=0.1,
                        type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--on_memory",
                        action='store_true',
                        help="Whether to load train samples into memory or use disk")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of updates steps to accumualte before performing a backward/update pass.")
    parser.add_argument('--fp16',
                        action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--loss_scale',
                        type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                             "0 (default value): dynamic loss scaling.\n"
                             "Positive power of 2: static loss scaling value.\n")
    return parser


def main():
    """Parse the command line, fine-tune BertForPreTraining and save its weights.

    :raises ValueError: on an invalid gradient_accumulation_steps value, when
        `--do_train` is not set, or when the output directory is not empty.
    """
    args = _build_arg_parser().parse_args()

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        # GPUs only count when they are actually used; otherwise a CPU model
        # would be wrapped in DataParallel on a multi-GPU host.
        n_gpu = torch.cuda.device_count() if device.type == "cuda" else 0
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
                            args.gradient_accumulation_steps))

    # the batch is split into `gradient_accumulation_steps` forward/backward passes
    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    # There is no `--do_eval` option: training is the only thing this script does.
    if not args.do_train:
        raise ValueError("Training is currently the only implemented execution option. Please set `do_train`.")

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)

    num_train_optimization_steps = None
    if args.do_train:
        print("Loading Train Dataset", args.train_file)
        train_dataset = BERTDataset(args.train_file, tokenizer, seq_len=args.max_seq_length,
                                    corpus_lines=None, on_memory=args.on_memory)
        num_train_optimization_steps = int(
            len(train_dataset) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
        if args.local_rank != -1:
            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()

    # Prepare model
    model = BertForPreTraining.from_pretrained(args.bert_model)
    if args.fp16:
        model.half()
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare optimizer: biases and LayerNorm parameters are excluded from weight decay
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]

    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)
    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_optimization_steps)

    global_step = 0
    if args.do_train:
        logger.info("***** Running training *****")
        logger.info(" Num examples = %d", len(train_dataset))
        logger.info(" Batch size = %d", args.train_batch_size)
        logger.info(" Num steps = %d", num_train_optimization_steps)

        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset)
        else:
            # TODO: check if this works with current data generator from disk that relies on next(file)
            # (it doesn't return item back by index)
            train_sampler = DistributedSampler(train_dataset)
        train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)

        model.train()
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, lm_label_ids, is_next = batch
                loss = model(input_ids, segment_ids, input_mask, lm_label_ids, is_next)
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps
                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()
                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        # if args.fp16 is False, BertAdam is used that handles this automatically
                        lr_this_step = args.learning_rate * warmup_linear(
                            global_step / num_train_optimization_steps, args.warmup_proportion)
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

        # Save a trained model
        logger.info("** ** * Saving fine - tuned model ** ** * ")
        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
        output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
        if args.do_train:
            torch.save(model_to_save.state_dict(), output_model_file)


def _truncate_seq_pair(tokens_a, tokens_b, max_length):
    """Truncates a sequence pair in place to the maximum length."""

    # This is a simple heuristic which will always truncate the longer sequence
    # one token at a time. This makes more sense than truncating an equal percent
    # of tokens from each, since if one sequence is very short then each token
    # that's truncated likely contains more information than a longer sequence.
    while len(tokens_a) + len(tokens_b) > max_length:
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            tokens_b.pop()


def accuracy(out, labels):
    """Return the number of rows of `out` whose argmax over axis 1 equals the label."""
    outputs = np.argmax(out, axis=1)
    return np.sum(outputs == labels)


if __name__ == "__main__":
    main()