Mirror of https://github.com/huggingface/transformers.git
Merge pull request #733 from ceremonious/parallel-generation
Added option to use multiple workers to create training data
This commit is contained in: commit 78462aad61
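At its core the change is the standard multiprocessing fan-out: one worker process per epoch, each writing its own file, dispatched with Pool.starmap so each argument tuple is unpacked into a call. A minimal sketch of that pattern follows; the write_epoch_file helper and the placeholder inputs are illustrative stand-ins, not part of the patch:

from multiprocessing import Pool

def write_epoch_file(docs, vocab_list, args, epoch_num):
    # Stand-in for create_training_file: each call produces one epoch's
    # file independently, so the workers need no shared state.
    print("generating epoch", epoch_num)

if __name__ == '__main__':
    docs, vocab_list, args = [], [], None  # placeholders for the real inputs
    epochs_to_generate, num_workers = 3, 2
    with Pool(min(num_workers, epochs_to_generate)) as pool:
        # starmap unpacks each tuple into the worker function's arguments
        pool.starmap(write_epoch_file, [(docs, vocab_list, args, i)
                                        for i in range(epochs_to_generate)])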
@@ -3,6 +3,7 @@ from pathlib import Path
 from tqdm import tqdm, trange
 from tempfile import TemporaryDirectory
 import shelve
+from multiprocessing import Pool
 
 from random import random, randrange, randint, shuffle, choice
 from pytorch_pretrained_bert.tokenization import BertTokenizer
@@ -264,6 +265,28 @@ def create_instances_from_document(
     return instances
 
 
+def create_training_file(docs, vocab_list, args, epoch_num):
+    epoch_filename = args.output_dir / "epoch_{}.json".format(epoch_num)
+    num_instances = 0
+    with epoch_filename.open('w') as epoch_file:
+        for doc_idx in trange(len(docs), desc="Document"):
+            doc_instances = create_instances_from_document(
+                docs, doc_idx, max_seq_length=args.max_seq_len, short_seq_prob=args.short_seq_prob,
+                masked_lm_prob=args.masked_lm_prob, max_predictions_per_seq=args.max_predictions_per_seq,
+                whole_word_mask=args.do_whole_word_mask, vocab_list=vocab_list)
+            doc_instances = [json.dumps(instance) for instance in doc_instances]
+            for instance in doc_instances:
+                epoch_file.write(instance + '\n')
+                num_instances += 1
+    metrics_file = args.output_dir / "epoch_{}_metrics.json".format(epoch_num)
+    with metrics_file.open('w') as metrics_file:
+        metrics = {
+            "num_training_examples": num_instances,
+            "max_seq_len": args.max_seq_len
+        }
+        metrics_file.write(json.dumps(metrics))
+
+
 def main():
     parser = ArgumentParser()
     parser.add_argument('--train_corpus', type=Path, required=True)
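The function above leaves each epoch as a file of newline-delimited JSON instances plus a small companion metrics file. A sketch of reading that output back, assuming an output directory and an epoch 0 named as in the code above:

import json
from pathlib import Path

output_dir = Path("training_data")  # assumed to match --output_dir
with (output_dir / "epoch_0.json").open() as f:
    instances = [json.loads(line) for line in f]  # one instance per line
with (output_dir / "epoch_0_metrics.json").open() as f:
    metrics = json.load(f)
assert metrics["num_training_examples"] == len(instances)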
@@ -277,6 +300,8 @@ def main():
     parser.add_argument("--reduce_memory", action="store_true",
                         help="Reduce memory usage for large datasets by keeping data on disc rather than in memory")
 
+    parser.add_argument("--num_workers", type=int, default=1,
+                        help="The number of workers to use to write the files")
     parser.add_argument("--epochs_to_generate", type=int, default=3,
                         help="Number of epochs of data to pregenerate")
     parser.add_argument("--max_seq_len", type=int, default=128)
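In isolation, the new flag behaves as a plain argparse integer option; a minimal check (a sketch, not the script's full parser):

from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument("--num_workers", type=int, default=1,
                    help="The number of workers to use to write the files")
args = parser.parse_args(["--num_workers", "4"])  # simulate CLI input
assert args.num_workers == 4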
@@ -289,6 +314,9 @@ def main():
 
     args = parser.parse_args()
 
+    if args.num_workers > 1 and args.reduce_memory:
+        raise ValueError("Cannot use multiple workers while reducing memory")
+
     tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
     vocab_list = list(tokenizer.vocab.keys())
     with DocumentDatabase(reduce_memory=args.reduce_memory) as docs:
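This guard is presumably needed because --reduce_memory backs DocumentDatabase with an on-disk shelve store, and a shelve handle cannot in general be pickled and shared across pool worker processes.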
@@ -312,26 +340,14 @@ def main():
              "sections or paragraphs.")
 
         args.output_dir.mkdir(exist_ok=True)
-        for epoch in trange(args.epochs_to_generate, desc="Epoch"):
-            epoch_filename = args.output_dir / f"epoch_{epoch}.json"
-            num_instances = 0
-            with epoch_filename.open('w') as epoch_file:
-                for doc_idx in trange(len(docs), desc="Document"):
-                    doc_instances = create_instances_from_document(
-                        docs, doc_idx, max_seq_length=args.max_seq_len, short_seq_prob=args.short_seq_prob,
-                        masked_lm_prob=args.masked_lm_prob, max_predictions_per_seq=args.max_predictions_per_seq,
-                        whole_word_mask=args.do_whole_word_mask, vocab_list=vocab_list)
-                    doc_instances = [json.dumps(instance) for instance in doc_instances]
-                    for instance in doc_instances:
-                        epoch_file.write(instance + '\n')
-                        num_instances += 1
-            metrics_file = args.output_dir / f"epoch_{epoch}_metrics.json"
-            with metrics_file.open('w') as metrics_file:
-                metrics = {
-                    "num_training_examples": num_instances,
-                    "max_seq_len": args.max_seq_len
-                }
-                metrics_file.write(json.dumps(metrics))
+
+        if args.num_workers > 1:
+            writer_workers = Pool(min(args.num_workers, args.epochs_to_generate))
+            arguments = [(docs, vocab_list, args, idx) for idx in range(args.epochs_to_generate)]
+            writer_workers.starmap(create_training_file, arguments)
+        else:
+            for epoch in trange(args.epochs_to_generate, desc="Epoch"):
+                create_training_file(docs, vocab_list, args, epoch)
 
 
 if __name__ == '__main__':
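Note that the pool size is capped at min(args.num_workers, args.epochs_to_generate): the unit of parallelism is a whole epoch file, so any workers beyond the number of epochs would sit idle.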