Update pregenerate_training_data.py
Apply the Whole Word Masking technique, following the reference implementation in [create_pretraining_data.py](https://github.com/google-research/bert/blob/master/create_pretraining_data.py) from google-research/bert.
Parent: ee0308f79d
Commit: a3a604cefb
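
As background for the diff below: with per-WordPiece masking every non-special token is its own masking candidate, while Whole Word Masking groups each "##"-prefixed continuation piece with the piece that starts the word, so a split word is masked or skipped as one unit. A minimal, self-contained sketch of that grouping rule (the example tokens are invented; only the grouping logic reflects this commit):

tokens = ["[CLS]", "the", "scient", "##ist", "ran", "[SEP]"]

# Per-WordPiece candidates (old behavior): every piece can be masked on its own.
per_wordpiece = [i for i, t in enumerate(tokens) if t not in ("[CLS]", "[SEP]")]
# -> [1, 2, 3, 4]

# Whole-word candidates (new behavior): "##" pieces join the previous word's group.
whole_word = []
for i, token in enumerate(tokens):
    if token in ("[CLS]", "[SEP]"):
        continue
    if whole_word and token.startswith("##"):
        whole_word[-1].append(i)
    else:
        whole_word.append([i])
# -> [[1], [2, 3], [4]]: "scient" and "##ist" form a single candidate.
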
@@ -4,11 +4,11 @@ from tqdm import tqdm, trange
 from tempfile import TemporaryDirectory
 import shelve
 
-from random import random, randrange, randint, shuffle, choice, sample
+from random import random, randrange, randint, shuffle, choice
 from pytorch_pretrained_bert.tokenization import BertTokenizer
 import numpy as np
 import json
+import collections
 
 class DocumentDatabase:
     def __init__(self, reduce_memory=False):
@@ -98,42 +98,77 @@ def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens):
         else:
             trunc_tokens.pop()
 
+MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
+                                          ["index", "label"])
 
-def create_masked_lm_predictions(tokens, masked_lm_prob, max_predictions_per_seq, vocab_list):
+def create_masked_lm_predictions(tokens, masked_lm_prob, max_predictions_per_seq, whole_word_mask, vocab_list):
     """Creates the predictions for the masked LM objective. This is mostly copied from the Google BERT repo, but
     with several refactors to clean it up and remove a lot of unnecessary variables."""
     cand_indices = []
     for (i, token) in enumerate(tokens):
         if token == "[CLS]" or token == "[SEP]":
             continue
-        cand_indices.append(i)
+        # Whole Word Masking means that if we mask all of the wordpieces
+        # corresponding to an original word. When a word has been split into
+        # WordPieces, the first token does not have any marker and any subsequence
+        # tokens are prefixed with ##. So whenever we see the ## token, we
+        # append it to the previous set of word indexes.
+        #
+        # Note that Whole Word Masking does *not* change the training code
+        # at all -- we still predict each WordPiece independently, softmaxed
+        # over the entire vocabulary.
+        if (whole_word_mask and len(cand_indices) >= 1 and token.startswith("##")):
+            cand_indices[-1].append(i)
+        else:
+            cand_indices.append([i])
 
     num_to_mask = min(max_predictions_per_seq,
                       max(1, int(round(len(tokens) * masked_lm_prob))))
     shuffle(cand_indices)
-    mask_indices = sorted(sample(cand_indices, num_to_mask))
-    masked_token_labels = []
-    for index in mask_indices:
-        # 80% of the time, replace with [MASK]
-        if random() < 0.8:
-            masked_token = "[MASK]"
-        else:
-            # 10% of the time, keep original
-            if random() < 0.5:
-                masked_token = tokens[index]
-            # 10% of the time, replace with random word
-            else:
-                masked_token = choice(vocab_list)
-        masked_token_labels.append(tokens[index])
-        # Once we've saved the true label for that token, we can overwrite it with the masked version
-        tokens[index] = masked_token
+    masked_lms = []
+    covered_indexes = set()
+    for index_set in cand_indices:
+        if len(masked_lms) >= num_to_mask:
+            break
+        # If adding a whole-word mask would exceed the maximum number of
+        # predictions, then just skip this candidate.
+        if len(masked_lms) + len(index_set) > num_to_mask:
+            continue
+        is_any_index_covered = False
+        for index in index_set:
+            if index in covered_indexes:
+                is_any_index_covered = True
+                break
+        if is_any_index_covered:
+            continue
+        for index in index_set:
+            covered_indexes.add(index)
+
+            masked_token = None
+            # 80% of the time, replace with [MASK]
+            if random() < 0.8:
+                masked_token = "[MASK]"
+            else:
+                # 10% of the time, keep original
+                if random() < 0.5:
+                    masked_token = tokens[index]
+                # 10% of the time, replace with random word
+                else:
+                    masked_token = choice(vocab_list)
+            masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
+            tokens[index] = masked_token
+
+    assert len(masked_lms) <= num_to_mask
+    masked_lms = sorted(masked_lms, key=lambda x: x.index)
+    mask_indices = [p.index for p in masked_lms]
+    masked_token_labels = [p.label for p in masked_lms]
 
     return tokens, mask_indices, masked_token_labels
 
 
 def create_instances_from_document(
         doc_database, doc_idx, max_seq_length, short_seq_prob,
-        masked_lm_prob, max_predictions_per_seq, vocab_list):
+        masked_lm_prob, max_predictions_per_seq, whole_word_mask, vocab_list):
     """This code is mostly a duplicate of the equivalent function from Google BERT's repo.
     However, we make some changes and improvements. Sampling is improved and no longer requires a loop in this function.
     Also, documents are sampled proportionally to the number of sentences they contain, which means each sentence
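
For illustration, a self-contained sketch of what the rewritten selection loop above does once a whole-word candidate has been picked: every WordPiece in the chosen index set is recorded as a prediction target with its original label, and the 80/10/10 replacement is drawn independently for each piece. The tokens, vocabulary, index set, and seed below are invented for the example; MaskedLmInstance and the replacement rule mirror the diff:

import collections
from random import random, choice, seed

MaskedLmInstance = collections.namedtuple("MaskedLmInstance", ["index", "label"])

tokens = ["[CLS]", "the", "scient", "##ist", "ran", "[SEP]"]
vocab_list = ["the", "a", "cat", "ran"]
index_set = [2, 3]        # one whole-word candidate: "scient" + "##ist"
masked_lms = []

seed(0)                   # seeded only so the sketch is repeatable
for index in index_set:   # all pieces of the word are handled together
    # 80% of the time, replace with [MASK]
    if random() < 0.8:
        masked_token = "[MASK]"
    else:
        # 10% keep the original piece, 10% substitute a random vocabulary item
        masked_token = tokens[index] if random() < 0.5 else choice(vocab_list)
    masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
    tokens[index] = masked_token

print(tokens)                                    # both pieces were processed as a unit
print([(p.index, p.label) for p in masked_lms])  # one label per WordPiece is preserved

As the diff's own comment notes, training is unchanged: the loss is still computed per WordPiece; only the choice of which positions to mask is now made word by word.
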
@@ -213,7 +248,7 @@ def create_instances_from_document(
     segment_ids = [0 for _ in range(len(tokens_a) + 2)] + [1 for _ in range(len(tokens_b) + 1)]
 
     tokens, masked_lm_positions, masked_lm_labels = create_masked_lm_predictions(
-        tokens, masked_lm_prob, max_predictions_per_seq, vocab_list)
+        tokens, masked_lm_prob, max_predictions_per_seq, whole_word_mask, vocab_list)
 
     instance = {
         "tokens": tokens,
@@ -237,7 +272,8 @@ def main():
                         choices=["bert-base-uncased", "bert-large-uncased", "bert-base-cased",
                                  "bert-base-multilingual", "bert-base-chinese"])
     parser.add_argument("--do_lower_case", action="store_true")
+    parser.add_argument("--do_whole_word_mask", action="store_true",
+                        help="Whether to use whole word masking rather than per-WordPiece masking.")
     parser.add_argument("--reduce_memory", action="store_true",
                         help="Reduce memory usage for large datasets by keeping data on disc rather than in memory")
 
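
The new option is a plain store_true flag, so the script's default behavior is unchanged unless --do_whole_word_mask is passed explicitly. A small stand-alone sketch of the flag's semantics (the parser here is a stand-in, not the script's full argument list):

from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument("--do_whole_word_mask", action="store_true",
                    help="Whether to use whole word masking rather than per-WordPiece masking.")

assert parser.parse_args([]).do_whole_word_mask is False                        # flag absent: old per-WordPiece masking
assert parser.parse_args(["--do_whole_word_mask"]).do_whole_word_mask is True   # flag present: whole word masking

Downstream, main() forwards the value as whole_word_mask=args.do_whole_word_mask (last hunk below), which is how it reaches create_masked_lm_predictions.
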
@@ -284,7 +320,7 @@ def main():
                 doc_instances = create_instances_from_document(
                     docs, doc_idx, max_seq_length=args.max_seq_len, short_seq_prob=args.short_seq_prob,
                     masked_lm_prob=args.masked_lm_prob, max_predictions_per_seq=args.max_predictions_per_seq,
-                    vocab_list=vocab_list)
+                    whole_word_mask=args.do_whole_word_mask, vocab_list=vocab_list)
                 doc_instances = [json.dumps(instance) for instance in doc_instances]
                 for instance in doc_instances:
                     epoch_file.write(instance + '\n')