Added a --reduce_memory option to shelve docs to disc instead of keeping them in memory.

Matthew Carrigan 2019-03-21 16:50:16 +00:00
parent 8733ffcb5e
commit 2bba7f810e
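The change hinges on Python's standard-library shelve module: a persistent, dict-like store whose values are pickled to a file on disk, so documents no longer have to stay resident in RAM. A minimal sketch of the pattern (hypothetical path and data, not code from this commit):

# Sketch of the shelve pattern this commit relies on. Keys must be strings,
# which is why the class below indexes documents as str(doc_idx).
import shelve
from pathlib import Path
from tempfile import TemporaryDirectory

with TemporaryDirectory() as temp_dir:
    shelf_path = Path(temp_dir) / 'shelf.db'
    # flag='n' always creates a fresh, empty database;
    # protocol=-1 pickles values with the highest available protocol.
    shelf = shelve.open(str(shelf_path), flag='n', protocol=-1)
    shelf['0'] = [['a', 'tokenized', 'sentence']]   # written to disk, not kept in RAM
    print(shelf['0'])                               # read back on demand
    shelf.close()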


@@ -1,46 +1,81 @@
 from argparse import ArgumentParser
 from pathlib import Path
 from tqdm import tqdm, trange
+from tempfile import TemporaryDirectory
+import shelve
 from random import random, randint, shuffle, choice, sample
 from pytorch_pretrained_bert.tokenization import BertTokenizer
 import numpy as np
 import json
 
 
 class DocumentDatabase:
-    def __init__(self, document_list):
-        self.document_list = document_list
-        self.doc_starts = {}
-        self.weighted_doc_samples = []
-        i = 0
-        for doc_idx, doc in enumerate(document_list):
-            self.doc_starts[doc_idx] = i
-            self.weighted_doc_samples.extend([doc_idx] * len(doc))
-            i += len(doc)
+    def __init__(self, reduce_memory=False, working_dir=None):
+        if reduce_memory:
+            if working_dir is None:
+                self.temp_dir = TemporaryDirectory()
+                self.working_dir = Path(self.temp_dir.name)
+            else:
+                self.temp_dir = None
+                self.working_dir = Path(working_dir)
+                self.working_dir.mkdir(parents=True, exist_ok=True)
+            self.document_shelf_filepath = self.working_dir / 'shelf.db'
+            self.document_shelf = shelve.open(str(self.document_shelf_filepath),
+                                              flag='n', protocol=-1)
+            self.documents = None
+        else:
+            self.documents = []
+            self.document_shelf = None
+            self.document_shelf_filepath = None
+        self.doc_lengths = []
+        self.doc_cumsum = None
+        self.cumsum_max = None
+        self.reduce_memory = reduce_memory
+
+    def add_document(self, document):
+        if self.reduce_memory:
+            current_idx = len(self.doc_lengths)
+            self.document_shelf[str(current_idx)] = document
+        else:
+            self.documents.append(document)
+        self.doc_lengths.append(len(document))
+
+    def _precalculate_doc_weights(self):
+        self.doc_cumsum = np.cumsum(self.doc_lengths)
+        self.cumsum_max = self.doc_cumsum[-1]
 
     def sample_doc(self, current_idx, sentence_weighted=True):
         # Uses the current iteration counter to ensure we don't sample the same doc twice
         if sentence_weighted:
-            num_sentences = len(self.document_list[current_idx])
-            # This very painful line randomly selects a document, weighted by the number of sentences they contain,
-            # while guaranteeing that it won't return the original document
-            sampled_val = (
-                (self.doc_starts[current_idx] + num_sentences
-                 + randint(0, len(self.weighted_doc_samples) - num_sentences - 1))
-                % len(self.weighted_doc_samples))
-            sampled_doc_index = self.weighted_doc_samples[sampled_val]
+            # With sentence weighting, we sample docs proportionally to their sentence length
+            if self.doc_cumsum is None or len(self.doc_cumsum) != len(self.doc_lengths):
+                self._precalculate_doc_weights()
+            rand_start = self.doc_cumsum[current_idx]
+            rand_end = rand_start + self.cumsum_max - self.doc_lengths[current_idx]
+            # randint includes both endpoints, so draw up to rand_end - 1; otherwise
+            # the modulo could wrap the draw back onto the current doc's own sentences
+            sentence_index = randint(rand_start, rand_end - 1) % self.cumsum_max
+            sampled_doc_index = np.searchsorted(self.doc_cumsum, sentence_index, side='right')
         else:
             # If we don't use sentence weighting, then every doc has an equal chance to be chosen
-            sampled_doc_index = current_idx + randint(1, len(self.document_list)-1)
+            # wrap with a modulo so the index stays in range while still never equalling current_idx
+            sampled_doc_index = (current_idx + randint(1, len(self.doc_lengths) - 1)) % len(self.doc_lengths)
         assert sampled_doc_index != current_idx
-        return self.document_list[sampled_doc_index]
+        if self.reduce_memory:
+            return self.document_shelf[str(sampled_doc_index)]
+        else:
+            return self.documents[sampled_doc_index]
 
     def __len__(self):
-        return len(self.document_list)
+        return len(self.doc_lengths)
 
     def __getitem__(self, item):
-        return self.document_list[item]
+        if self.reduce_memory:
+            return self.document_shelf[str(item)]
+        else:
+            return self.documents[item]
 
+    def cleanup(self):
+        if self.document_shelf is not None:
+            self.document_shelf.close()
+
 
 def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens):
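The rewritten sample_doc also drops the old weighted_doc_samples list, which held one entry per sentence in the corpus, in favor of a cumulative-sum lookup: draw a sentence index outside the current document's span, then map it back to a document with np.searchsorted. A standalone sketch with hypothetical lengths, showing both the length weighting and the never-resample guarantee:

# Standalone sketch of the cumsum + searchsorted sampling used above.
# doc_lengths is made-up example data; the logic mirrors sample_doc.
import numpy as np
from random import randrange

doc_lengths = [2, 5, 3]                  # sentences per document
doc_cumsum = np.cumsum(doc_lengths)      # array([ 2,  7, 10])
cumsum_max = int(doc_cumsum[-1])

def sample_other_doc(current_idx):
    # Doc i owns sentence indices [doc_cumsum[i] - doc_lengths[i], doc_cumsum[i]).
    # Starting the draw just past the current doc and wrapping modulo cumsum_max
    # walks every other doc's sentences but never the current doc's.
    rand_start = int(doc_cumsum[current_idx])
    rand_end = rand_start + cumsum_max - doc_lengths[current_idx]
    sentence_index = randrange(rand_start, rand_end) % cumsum_max
    return int(np.searchsorted(doc_cumsum, sentence_index, side='right'))

counts = [0, 0, 0]
for _ in range(10_000):
    counts[sample_other_doc(0)] += 1
print(counts)   # counts[0] == 0; docs 1 and 2 appear in roughly a 5:3 ratio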
@@ -200,6 +235,11 @@ def main():
                                  "bert-base-multilingual", "bert-base-chinese"])
     parser.add_argument("--do_lower_case", action="store_true")
+    parser.add_argument("--reduce_memory", action="store_true",
+                        help="Reduce memory usage for large datasets by keeping data on disc rather than in memory")
+    parser.add_argument("--working_dir", type=Path, default=None,
+                        help="Temporary directory to use for --reduce_memory. If not set, uses TemporaryDirectory()")
     parser.add_argument("--epochs_to_generate", type=int, default=3,
                         help="Number of epochs of data to pregenerate")
     parser.add_argument("--max_seq_len", type=int, default=128)
@@ -212,31 +252,21 @@ def main():
     args = parser.parse_args()
 
-    # TODO Add a low-memory / multiprocessing path for very large datasets
-    # In this path documents would be stored in a shelf after being tokenized, and multiple processes would convert
-    # those docs into training examples that would be written out on the fly. This would avoid the need to keep
-    # the whole training set in memory and would speed up dataset creation at the cost of code complexity.
-    # In addition, the finetuning script would need to be modified
-    # to store the training epochs as memmapped arrays.
     tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
     vocab_list = list(tokenizer.vocab.keys())
 
+    docs = DocumentDatabase(reduce_memory=args.reduce_memory, working_dir=args.working_dir)
     with args.train_corpus.open() as f:
-        docs = []
         doc = []
-        for line in tqdm(f, desc="Loading Dataset"):
+        for line in tqdm(f, desc="Loading Dataset", unit=" lines"):
             line = line.strip()
             if line == "":
-                docs.append(doc)
+                docs.add_document(doc)
                 doc = []
             else:
                 tokens = tokenizer.tokenize(line)
                 doc.append(tokens)
 
     args.output_dir.mkdir(exist_ok=True)
-    docs = DocumentDatabase(docs)
-    # When choosing a random sentence, we should sample docs proportionally to the number of sentences they contain
-    # Google BERT doesn't do this, and as a result oversamples shorter docs
     for epoch in trange(args.epochs_to_generate, desc="Epoch"):
         epoch_filename = args.output_dir / f"epoch_{epoch}.json"
         num_instances = 0
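One caveat in the loading loop above: a document is only flushed into the database when the blank line after it is seen, so a corpus that does not end with a blank line appears to drop its final document. A self-contained sketch of the same loop over toy data, with a defensive final flush added (the flush is this sketch's addition, not something the script does):

# Toy corpus: one sentence per line, blank line between documents.
corpus = "Doc one, sentence one.\nDoc one, sentence two.\n\nDoc two, only sentence.\n"

docs = []          # stand-in for DocumentDatabase with --reduce_memory off
doc = []
for line in corpus.splitlines():
    line = line.strip()
    if line == "":
        docs.append(doc)            # the script calls docs.add_document(doc) here
        doc = []
    else:
        doc.append(line.split())    # stand-in for tokenizer.tokenize(line)
if doc:                             # flush a final doc with no trailing blank line
    docs.append(doc)
print(len(docs))                    # 2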
@@ -257,6 +287,7 @@ def main():
                 "max_seq_len": args.max_seq_len
             }
             metrics_file.write(json.dumps(metrics))
+    docs.cleanup()
 
 
 if __name__ == '__main__':
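Finally, since the shelf should be closed even if generation fails partway, one defensive pattern is to wrap the database in a context manager. A sketch assuming DocumentDatabase from the diff above is in scope (the script itself simply calls docs.cleanup() at the end of main()):

# Sketch only: DocumentDatabase is the class from the diff above.
from contextlib import contextmanager

@contextmanager
def document_database(**kwargs):
    db = DocumentDatabase(**kwargs)
    try:
        yield db
    finally:
        db.cleanup()   # close the shelf even if an exception was raised

# Usage:
# with document_database(reduce_memory=True) as docs:
#     docs.add_document([["hello", "world"]])
#     print(len(docs), docs[0])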