Update pregenerate_training_data.py

Apply the Whole Word Masking technique.
Adapted from [create_pretraining_data.py](https://github.com/google-research/bert/blob/master/create_pretraining_data.py).
jeonsworld 2019-06-10 12:17:23 +09:00 committed by GitHub
parent ee0308f79d
commit a3a604cefb

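In short, the script now groups WordPiece indices into per-word candidate sets before sampling masks, so every piece of a word is masked together or not at all. The toy sketch below is illustrative only (hypothetical helper name and token list, not part of the diff); it mirrors the candidate-grouping loop this commit adds to `create_masked_lm_predictions` (see the diff below).

```python
def group_whole_word_candidates(tokens, whole_word_mask=True):
    """Group token indices so that each inner list covers one original word."""
    cand_indices = []
    for i, token in enumerate(tokens):
        if token in ("[CLS]", "[SEP]"):
            continue
        # A "##" prefix marks a continuation WordPiece, so it joins the
        # previous word's index set instead of opening a new candidate.
        if whole_word_mask and cand_indices and token.startswith("##"):
            cand_indices[-1].append(i)
        else:
            cand_indices.append([i])
    return cand_indices

tokens = ["[CLS]", "un", "##afford", "##able", "prices", "[SEP]"]
print(group_whole_word_candidates(tokens))             # [[1, 2, 3], [4]]
print(group_whole_word_candidates(tokens, False))      # [[1], [2], [3], [4]]
```

As the comment in the diff stresses, this only changes which positions are selected for masking; each WordPiece is still predicted independently over the full vocabulary.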

@@ -4,11 +4,11 @@ from tqdm import tqdm, trange
 from tempfile import TemporaryDirectory
 import shelve

-from random import random, randrange, randint, shuffle, choice, sample
+from random import random, randrange, randint, shuffle, choice
 from pytorch_pretrained_bert.tokenization import BertTokenizer
 import numpy as np
 import json
+import collections


 class DocumentDatabase:
     def __init__(self, reduce_memory=False):
@@ -98,42 +98,77 @@ def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens):
         else:
             trunc_tokens.pop()


+MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
+                                          ["index", "label"])
+
+
-def create_masked_lm_predictions(tokens, masked_lm_prob, max_predictions_per_seq, vocab_list):
+def create_masked_lm_predictions(tokens, masked_lm_prob, max_predictions_per_seq, whole_word_mask, vocab_list):
     """Creates the predictions for the masked LM objective. This is mostly copied from the Google BERT repo, but
     with several refactors to clean it up and remove a lot of unnecessary variables."""
     cand_indices = []
     for (i, token) in enumerate(tokens):
         if token == "[CLS]" or token == "[SEP]":
             continue
-        cand_indices.append(i)
+        # Whole Word Masking means that we mask all of the wordpieces
+        # corresponding to an original word. When a word has been split into
+        # WordPieces, the first token does not have any marker and any subsequent
+        # tokens are prefixed with ##. So whenever we see the ## token, we
+        # append it to the previous set of word indexes.
+        #
+        # Note that Whole Word Masking does *not* change the training code
+        # at all -- we still predict each WordPiece independently, softmaxed
+        # over the entire vocabulary.
+        if (whole_word_mask and len(cand_indices) >= 1 and token.startswith("##")):
+            cand_indices[-1].append(i)
+        else:
+            cand_indices.append([i])

     num_to_mask = min(max_predictions_per_seq,
                       max(1, int(round(len(tokens) * masked_lm_prob))))
     shuffle(cand_indices)
-    mask_indices = sorted(sample(cand_indices, num_to_mask))
-    masked_token_labels = []
-    for index in mask_indices:
-        # 80% of the time, replace with [MASK]
-        if random() < 0.8:
-            masked_token = "[MASK]"
-        else:
-            # 10% of the time, keep original
-            if random() < 0.5:
-                masked_token = tokens[index]
-            # 10% of the time, replace with random word
-            else:
-                masked_token = choice(vocab_list)
-        masked_token_labels.append(tokens[index])
-        # Once we've saved the true label for that token, we can overwrite it with the masked version
-        tokens[index] = masked_token
+    masked_lms = []
+    covered_indexes = set()
+    for index_set in cand_indices:
+        if len(masked_lms) >= num_to_mask:
+            break
+        # If adding a whole-word mask would exceed the maximum number of
+        # predictions, then just skip this candidate.
+        if len(masked_lms) + len(index_set) > num_to_mask:
+            continue
+        is_any_index_covered = False
+        for index in index_set:
+            if index in covered_indexes:
+                is_any_index_covered = True
+                break
+        if is_any_index_covered:
+            continue
+        for index in index_set:
+            covered_indexes.add(index)
+
+            masked_token = None
+            # 80% of the time, replace with [MASK]
+            if random() < 0.8:
+                masked_token = "[MASK]"
+            else:
+                # 10% of the time, keep original
+                if random() < 0.5:
+                    masked_token = tokens[index]
+                # 10% of the time, replace with random word
+                else:
+                    masked_token = choice(vocab_list)
+            masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
+            tokens[index] = masked_token
+
+    assert len(masked_lms) <= num_to_mask
+    masked_lms = sorted(masked_lms, key=lambda x: x.index)
+    mask_indices = [p.index for p in masked_lms]
+    masked_token_labels = [p.label for p in masked_lms]

     return tokens, mask_indices, masked_token_labels


 def create_instances_from_document(
         doc_database, doc_idx, max_seq_length, short_seq_prob,
-        masked_lm_prob, max_predictions_per_seq, vocab_list):
+        masked_lm_prob, max_predictions_per_seq, whole_word_mask, vocab_list):
     """This code is mostly a duplicate of the equivalent function from Google BERT's repo.
     However, we make some changes and improvements. Sampling is improved and no longer requires a loop in this function.
     Also, documents are sampled proportionally to the number of sentences they contain, which means each sentence
@@ -213,7 +248,7 @@ def create_instances_from_document(
                 segment_ids = [0 for _ in range(len(tokens_a) + 2)] + [1 for _ in range(len(tokens_b) + 1)]

                 tokens, masked_lm_positions, masked_lm_labels = create_masked_lm_predictions(
-                    tokens, masked_lm_prob, max_predictions_per_seq, vocab_list)
+                    tokens, masked_lm_prob, max_predictions_per_seq, whole_word_mask, vocab_list)

                 instance = {
                     "tokens": tokens,
@ -237,7 +272,8 @@ def main():
choices=["bert-base-uncased", "bert-large-uncased", "bert-base-cased", choices=["bert-base-uncased", "bert-large-uncased", "bert-base-cased",
"bert-base-multilingual", "bert-base-chinese"]) "bert-base-multilingual", "bert-base-chinese"])
parser.add_argument("--do_lower_case", action="store_true") parser.add_argument("--do_lower_case", action="store_true")
parser.add_argument("--do_whole_word_mask", action="store_true",
help="Whether to use whole word masking rather than per-WordPiece masking.")
parser.add_argument("--reduce_memory", action="store_true", parser.add_argument("--reduce_memory", action="store_true",
help="Reduce memory usage for large datasets by keeping data on disc rather than in memory") help="Reduce memory usage for large datasets by keeping data on disc rather than in memory")
@@ -284,7 +320,7 @@ def main():
                     doc_instances = create_instances_from_document(
                         docs, doc_idx, max_seq_length=args.max_seq_len, short_seq_prob=args.short_seq_prob,
                         masked_lm_prob=args.masked_lm_prob, max_predictions_per_seq=args.max_predictions_per_seq,
-                        vocab_list=vocab_list)
+                        whole_word_mask=args.do_whole_word_mask, vocab_list=vocab_list)
                     doc_instances = [json.dumps(instance) for instance in doc_instances]
                     for instance in doc_instances:
                         epoch_file.write(instance + '\n')
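A minimal usage sketch of the updated function, assuming the script (and its `pytorch_pretrained_bert` dependency) is importable as `pregenerate_training_data`, with a toy token list and vocabulary standing in for real `BertTokenizer` output:

```python
from pregenerate_training_data import create_masked_lm_predictions

# Toy WordPiece sequence and vocabulary (illustrative only).
tokens = ["[CLS]", "the", "un", "##afford", "##able", "prices", "[SEP]"]
vocab_list = ["the", "un", "##afford", "##able", "prices", "cheap"]

# Note: the function shuffles the candidates and mutates `tokens` in place.
masked_tokens, mask_positions, mask_labels = create_masked_lm_predictions(
    tokens, masked_lm_prob=0.5, max_predictions_per_seq=5,
    whole_word_mask=True, vocab_list=vocab_list)

# With whole_word_mask=True, indices 2-4 ("un", "##afford", "##able") are
# treated as one unit: either all three are masked or none of them is.
print(masked_tokens)
print(mask_positions, mask_labels)
```

From the command line, the same behaviour is switched on by passing the new `--do_whole_word_mask` flag added above.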