From a3a604cefbd96d6f23366b6c9c87c3e98889461c Mon Sep 17 00:00:00 2001
From: jeonsworld <37530102+jeonsworld@users.noreply.github.com>
Date: Mon, 10 Jun 2019 12:17:23 +0900
Subject: [PATCH] Update pregenerate_training_data.py

Apply the Whole Word Masking technique, referring to
[create_pretraining_data.py](https://github.com/google-research/bert/blob/master/create_pretraining_data.py).
---
 .../pregenerate_training_data.py | 82 +++++++++++++------
 1 file changed, 59 insertions(+), 23 deletions(-)

diff --git a/examples/lm_finetuning/pregenerate_training_data.py b/examples/lm_finetuning/pregenerate_training_data.py
index e6c3598a9fe..6cb89544657 100644
--- a/examples/lm_finetuning/pregenerate_training_data.py
+++ b/examples/lm_finetuning/pregenerate_training_data.py
@@ -4,11 +4,11 @@ from tqdm import tqdm, trange
 from tempfile import TemporaryDirectory
 import shelve
 
-from random import random, randrange, randint, shuffle, choice, sample
+from random import random, randrange, randint, shuffle, choice
 from pytorch_pretrained_bert.tokenization import BertTokenizer
 import numpy as np
 import json
-
+import collections
 
 class DocumentDatabase:
     def __init__(self, reduce_memory=False):
@@ -98,42 +98,77 @@ def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens):
         else:
             trunc_tokens.pop()
 
+MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
+                                          ["index", "label"])
 
-def create_masked_lm_predictions(tokens, masked_lm_prob, max_predictions_per_seq, vocab_list):
+def create_masked_lm_predictions(tokens, masked_lm_prob, max_predictions_per_seq, whole_word_mask, vocab_list):
     """Creates the predictions for the masked LM objective. This is mostly copied from the Google BERT repo, but
     with several refactors to clean it up and remove a lot of unnecessary variables."""
     cand_indices = []
     for (i, token) in enumerate(tokens):
         if token == "[CLS]" or token == "[SEP]":
             continue
-        cand_indices.append(i)
+        # Whole Word Masking means that if we mask all of the wordpieces
+        # corresponding to an original word. When a word has been split into
+        # WordPieces, the first token does not have any marker and any subsequence
+        # tokens are prefixed with ##. So whenever we see the ## token, we
+        # append it to the previous set of word indexes.
+        #
+        # Note that Whole Word Masking does *not* change the training code
+        # at all -- we still predict each WordPiece independently, softmaxed
+        # over the entire vocabulary.
+        if (whole_word_mask and len(cand_indices) >= 1 and token.startswith("##")):
+            cand_indices[-1].append(i)
+        else:
+            cand_indices.append([i])
 
     num_to_mask = min(max_predictions_per_seq,
                       max(1, int(round(len(tokens) * masked_lm_prob))))
     shuffle(cand_indices)
-    mask_indices = sorted(sample(cand_indices, num_to_mask))
-    masked_token_labels = []
-    for index in mask_indices:
-        # 80% of the time, replace with [MASK]
-        if random() < 0.8:
-            masked_token = "[MASK]"
-        else:
-            # 10% of the time, keep original
-            if random() < 0.5:
-                masked_token = tokens[index]
-            # 10% of the time, replace with random word
+    masked_lms = []
+    covered_indexes = set()
+    for index_set in cand_indices:
+        if len(masked_lms) >= num_to_mask:
+            break
+        # If adding a whole-word mask would exceed the maximum number of
+        # predictions, then just skip this candidate.
+        if len(masked_lms) + len(index_set) > num_to_mask:
+            continue
+        is_any_index_covered = False
+        for index in index_set:
+            if index in covered_indexes:
+                is_any_index_covered = True
+                break
+        if is_any_index_covered:
+            continue
+        for index in index_set:
+            covered_indexes.add(index)
+
+            masked_token = None
+            # 80% of the time, replace with [MASK]
+            if random() < 0.8:
+                masked_token = "[MASK]"
             else:
-                masked_token = choice(vocab_list)
-        masked_token_labels.append(tokens[index])
-        # Once we've saved the true label for that token, we can overwrite it with the masked version
-        tokens[index] = masked_token
+                # 10% of the time, keep original
+                if random() < 0.5:
+                    masked_token = tokens[index]
+                # 10% of the time, replace with random word
+                else:
+                    masked_token = choice(vocab_list)
+            masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
+            tokens[index] = masked_token
+
+    assert len(masked_lms) <= num_to_mask
+    masked_lms = sorted(masked_lms, key=lambda x: x.index)
+    mask_indices = [p.index for p in masked_lms]
+    masked_token_labels = [p.label for p in masked_lms]
 
     return tokens, mask_indices, masked_token_labels
 
 
 def create_instances_from_document(
         doc_database, doc_idx, max_seq_length, short_seq_prob,
-        masked_lm_prob, max_predictions_per_seq, vocab_list):
+        masked_lm_prob, max_predictions_per_seq, whole_word_mask, vocab_list):
     """This code is mostly a duplicate of the equivalent function from Google BERT's repo. However, we make some changes
     and improvements. Sampling is improved and no longer requires a loop in this function.
     Also, documents are sampled proportionally to the number of sentences they contain, which means each sentence
@@ -213,7 +248,7 @@ def create_instances_from_document(
                 segment_ids = [0 for _ in range(len(tokens_a) + 2)] + [1 for _ in range(len(tokens_b) + 1)]
 
                 tokens, masked_lm_positions, masked_lm_labels = create_masked_lm_predictions(
-                    tokens, masked_lm_prob, max_predictions_per_seq, vocab_list)
+                    tokens, masked_lm_prob, max_predictions_per_seq, whole_word_mask, vocab_list)
 
                 instance = {
                     "tokens": tokens,
@@ -237,7 +272,8 @@ def main():
                         choices=["bert-base-uncased", "bert-large-uncased", "bert-base-cased",
                                  "bert-base-multilingual", "bert-base-chinese"])
     parser.add_argument("--do_lower_case", action="store_true")
-
+    parser.add_argument("--do_whole_word_mask", action="store_true",
+                        help="Whether to use whole word masking rather than per-WordPiece masking.")
     parser.add_argument("--reduce_memory", action="store_true",
                         help="Reduce memory usage for large datasets by keeping data on disc rather than in memory")
 
@@ -284,7 +320,7 @@ def main():
                 doc_instances = create_instances_from_document(
                     docs, doc_idx, max_seq_length=args.max_seq_len, short_seq_prob=args.short_seq_prob,
                     masked_lm_prob=args.masked_lm_prob, max_predictions_per_seq=args.max_predictions_per_seq,
-                    vocab_list=vocab_list)
+                    whole_word_mask=args.do_whole_word_mask, vocab_list=vocab_list)
                 doc_instances = [json.dumps(instance) for instance in doc_instances]
                 for instance in doc_instances:
                     epoch_file.write(instance + '\n')
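
Note (not part of the diff): the snippet below is a minimal, standalone sketch of how the whole-word candidate grouping added in `create_masked_lm_predictions` behaves when `whole_word_mask` is True. The token strings are made up for illustration; the loop just mirrors the grouping logic introduced by this patch.

```python
# Illustrative only: mirrors the cand_indices grouping above with whole_word_mask=True.
# The example WordPieces are hypothetical.
tokens = ["[CLS]", "un", "##afford", "##able", "prices", "[SEP]"]

cand_indices = []
for i, token in enumerate(tokens):
    if token == "[CLS]" or token == "[SEP]":
        continue
    # A "##" piece continues the previous word, so its index joins that word's group;
    # otherwise it starts a new candidate group.
    if len(cand_indices) >= 1 and token.startswith("##"):
        cand_indices[-1].append(i)
    else:
        cand_indices.append([i])

print(cand_indices)  # [[1, 2, 3], [4]] -> "un ##afford ##able" is treated as one candidate
```

As the in-line comment in the patch says, only candidate selection changes: every WordPiece in a chosen group still gets its own 80/10/10 replacement draw and is predicted independently over the full vocabulary.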