Mirror of https://github.com/huggingface/transformers.git
Added a --reduce_memory option to shelve docs to disc instead of keeping them in memory.
commit 2bba7f810e (parent 8733ffcb5e)
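The change builds on Python's standard-library shelve module: a persistent, dict-like object backed by a file on disc, so tokenized documents no longer have to sit in RAM. A minimal sketch of the pattern (file name and stored contents here are hypothetical, not taken from the commit):

    import shelve
    from pathlib import Path
    from tempfile import TemporaryDirectory

    # Shelf keys must be strings; flag='n' always creates a fresh, empty
    # database, and protocol=-1 pickles values with the newest protocol.
    temp_dir = TemporaryDirectory()
    shelf = shelve.open(str(Path(temp_dir.name) / 'shelf.db'), flag='n', protocol=-1)
    shelf['0'] = [['a', 'tokenized', 'sentence'], ['another', 'one']]
    print(shelf['0'])  # values are unpickled from disc on access
    shelf.close()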
@@ -1,46 +1,81 @@
 from argparse import ArgumentParser
 from pathlib import Path
 from tqdm import tqdm, trange
+from tempfile import TemporaryDirectory
+import shelve
 
 from random import random, randint, shuffle, choice, sample
 from pytorch_pretrained_bert.tokenization import BertTokenizer
 
 import numpy as np
 import json
 
 
 class DocumentDatabase:
-    def __init__(self, document_list):
-        self.document_list = document_list
-        self.doc_starts = {}
-        self.weighted_doc_samples = []
-        i = 0
-        for doc_idx, doc in enumerate(document_list):
-            self.doc_starts[doc_idx] = i
-            self.weighted_doc_samples.extend([doc_idx] * len(doc))
-            i += len(doc)
+    def __init__(self, reduce_memory=False, working_dir=None):
+        if reduce_memory:
+            if working_dir is None:
+                self.temp_dir = TemporaryDirectory()
+                self.working_dir = Path(self.temp_dir.name)
+            else:
+                self.temp_dir = None
+                self.working_dir = Path(working_dir)
+                self.working_dir.mkdir(parents=True, exist_ok=True)
+            self.document_shelf_filepath = self.working_dir / 'shelf.db'
+            self.document_shelf = shelve.open(str(self.document_shelf_filepath),
+                                              flag='n', protocol=-1)
+            self.documents = None
+        else:
+            self.documents = []
+            self.document_shelf = None
+            self.document_shelf_filepath = None
+        self.doc_lengths = []
+        self.doc_cumsum = None
+        self.cumsum_max = None
+        self.reduce_memory = reduce_memory
+
+    def add_document(self, document):
+        if self.reduce_memory:
+            current_idx = len(self.doc_lengths)
+            self.document_shelf[str(current_idx)] = document
+        else:
+            self.documents.append(document)
+        self.doc_lengths.append(len(document))
+
+    def _precalculate_doc_weights(self):
+        self.doc_cumsum = np.cumsum(self.doc_lengths)
+        self.cumsum_max = self.doc_cumsum[-1]
 
     def sample_doc(self, current_idx, sentence_weighted=True):
         # Uses the current iteration counter to ensure we don't sample the same doc twice
         if sentence_weighted:
-            num_sentences = len(self.document_list[current_idx])
-            # This very painful line randomly selects a document, weighted by the number of sentences they contain,
-            # while guaranteeing that it won't return the original document
-            sampled_val = (
-                (self.doc_starts[current_idx] + num_sentences
-                 + randint(0, len(self.weighted_doc_samples) - num_sentences - 1))
-                % len(self.weighted_doc_samples))
-            sampled_doc_index = self.weighted_doc_samples[sampled_val]
+            # With sentence weighting, we sample docs proportionally to their sentence length
+            if self.doc_cumsum is None or len(self.doc_cumsum) != len(self.doc_lengths):
+                self._precalculate_doc_weights()
+            rand_start = self.doc_cumsum[current_idx]
+            rand_end = rand_start + self.cumsum_max - self.doc_lengths[current_idx]
+            sentence_index = randint(rand_start, rand_end) % self.cumsum_max
+            sampled_doc_index = np.searchsorted(self.doc_cumsum, sentence_index, side='right')
         else:
-            sampled_doc_index = current_idx + randint(1, len(self.document_list)-1)
+            # If we don't use sentence weighting, then every doc has an equal chance to be chosen
+            sampled_doc_index = current_idx + randint(1, len(self.doc_lengths)-1)
         assert sampled_doc_index != current_idx
-        return self.document_list[sampled_doc_index]
+        if self.reduce_memory:
+            return self.document_shelf[str(sampled_doc_index)]
+        else:
+            return self.documents[sampled_doc_index]
 
     def __len__(self):
-        return len(self.document_list)
+        return len(self.doc_lengths)
 
     def __getitem__(self, item):
-        return self.document_list[item]
+        if self.reduce_memory:
+            return self.document_shelf[str(item)]
+        else:
+            return self.documents[item]
+
+    def cleanup(self):
+        if self.document_shelf is not None:
+            self.document_shelf.close()
 
 
 def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens):
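The new sample_doc replaces the per-sentence weighted_doc_samples list (one entry per sentence in the corpus) with a cumulative-sum lookup. A toy walkthrough of the scheme (the lengths are made up; note that random.randint is inclusive on both ends, so this sketch draws only up to rand_end - 1 to keep the modulo wrap-around from landing back on the current document):

    import numpy as np
    from random import randint

    doc_lengths = [3, 5, 2]              # sentences per document
    doc_cumsum = np.cumsum(doc_lengths)  # array([ 3,  8, 10])
    cumsum_max = int(doc_cumsum[-1])     # 10 sentences overall

    current_idx = 0                      # the document we must not resample
    rand_start = int(doc_cumsum[current_idx])                      # 3
    rand_end = rand_start + cumsum_max - doc_lengths[current_idx]  # 10
    # Draws 3..9 cover every sentence outside document 0 exactly once,
    # so each other document is chosen in proportion to its length.
    sentence_index = randint(rand_start, rand_end - 1) % cumsum_max
    sampled_doc_index = int(np.searchsorted(doc_cumsum, sentence_index, side='right'))
    assert sampled_doc_index != current_idx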
@@ -200,6 +235,11 @@ def main():
                                  "bert-base-multilingual", "bert-base-chinese"])
     parser.add_argument("--do_lower_case", action="store_true")
 
+    parser.add_argument("--reduce_memory", action="store_true",
+                        help="Reduce memory usage for large datasets by keeping data on disc rather than in memory")
+    parser.add_argument("--working_dir", type=Path, default=None,
+                        help="Temporary directory to use for --reduce_memory. If not set, uses TemporaryDirectory()")
+
     parser.add_argument("--epochs_to_generate", type=int, default=3,
                         help="Number of epochs of data to pregenerate")
     parser.add_argument("--max_seq_len", type=int, default=128)
@@ -212,31 +252,21 @@ def main():
 
     args = parser.parse_args()
 
-    # TODO Add a low-memory / multiprocessing path for very large datasets
-    # In this path documents would be stored in a shelf after being tokenized, and multiple processes would convert
-    # those docs into training examples that would be written out on the fly. This would avoid the need to keep
-    # the whole training set in memory and would speed up dataset creation at the cost of code complexity.
-    # In addition, the finetuning script would need to be modified
-    # to store the training epochs as memmapped arrays.
-
     tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
     vocab_list = list(tokenizer.vocab.keys())
+    docs = DocumentDatabase(reduce_memory=args.reduce_memory, working_dir=args.working_dir)
     with args.train_corpus.open() as f:
-        docs = []
         doc = []
-        for line in tqdm(f, desc="Loading Dataset"):
+        for line in tqdm(f, desc="Loading Dataset", unit=" lines"):
             line = line.strip()
             if line == "":
-                docs.append(doc)
+                docs.add_document(doc)
                 doc = []
             else:
                 tokens = tokenizer.tokenize(line)
                 doc.append(tokens)
 
     args.output_dir.mkdir(exist_ok=True)
-    docs = DocumentDatabase(docs)
-    # When choosing a random sentence, we should sample docs proportionally to the number of sentences they contain
-    # Google BERT doesn't do this, and as a result oversamples shorter docs
     for epoch in trange(args.epochs_to_generate, desc="Epoch"):
         epoch_filename = args.output_dir / f"epoch_{epoch}.json"
         num_instances = 0
@@ -257,6 +287,7 @@ def main():
                 "max_seq_len": args.max_seq_len
             }
             metrics_file.write(json.dumps(metrics))
+    docs.cleanup()
 
 
 if __name__ == '__main__':
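Taken together, the new interface is used like this (a minimal sketch assuming the DocumentDatabase class from this commit is in scope; the working directory path and document contents are hypothetical):

    # Documents are pickled to a shelf on disc as they are added, and
    # cleanup() closes the shelf once pregeneration is finished.
    docs = DocumentDatabase(reduce_memory=True, working_dir='/tmp/bert_shelf')
    docs.add_document([['first', 'sentence'], ['second', 'sentence']])
    docs.add_document([['another', 'document']])
    print(len(docs))  # 2
    print(docs[1])    # read back from disc: [['another', 'document']]
    docs.cleanup()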