Merge pull request #478 from Rocketknight1/master

Added a helpful error for users with single-document corpuses - fixes #452
Thomas Wolf 2019-04-15 10:55:57 +02:00 committed by GitHub
commit 7873d76464

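For context on why the new check is needed, here is a minimal standalone sketch (not part of the commit; the doc_lengths list is hypothetical) of what happens inside DocumentDatabase.sample_doc() when the corpus contains only one document: there is no other document to draw a random NextSentence from, so the sampling range is empty.

from random import randrange

# Hypothetical single-document corpus: NextSentence negatives must come from a
# *different* document, but there is no other document to choose from.
doc_lengths = [25]
current_idx = 0

try:
    # Mirrors the unweighted branch of sample_doc(): randrange(1, 1) is an
    # empty range and raises ValueError before any index is produced.
    sampled_doc_index = (current_idx + randrange(1, len(doc_lengths))) % len(doc_lengths)
except ValueError:
    print("Cannot sample a different document from a single-document corpus.")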

@@ -4,7 +4,7 @@ from tqdm import tqdm, trange
 from tempfile import TemporaryDirectory
 import shelve
-from random import random, randint, shuffle, choice, sample
+from random import random, randrange, randint, shuffle, choice, sample
 from pytorch_pretrained_bert.tokenization import BertTokenizer
 import numpy as np
 import json
@@ -30,6 +30,8 @@ class DocumentDatabase:
         self.reduce_memory = reduce_memory
 
     def add_document(self, document):
+        if not document:
+            return
         if self.reduce_memory:
             current_idx = len(self.doc_lengths)
             self.document_shelf[str(current_idx)] = document
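The new guard simply drops empty documents, so consecutive blank lines in the corpus no longer create zero-length entries. A minimal sketch of that behaviour (a hypothetical stand-in, not the real DocumentDatabase class):

doc_lengths, documents = [], []

def add_document(document):
    # Mirrors the new guard: an empty document (e.g. produced by two
    # consecutive blank lines in the corpus) is ignored instead of stored.
    if not document:
        return
    doc_lengths.append(len(document))
    documents.append(document)

add_document([])                       # ignored
add_document([["only", "sentence"]])   # stored
assert len(documents) == 1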
@@ -49,11 +51,11 @@ class DocumentDatabase:
                 self._precalculate_doc_weights()
             rand_start = self.doc_cumsum[current_idx]
             rand_end = rand_start + self.cumsum_max - self.doc_lengths[current_idx]
-            sentence_index = randint(rand_start, rand_end-1) % self.cumsum_max
+            sentence_index = randrange(rand_start, rand_end) % self.cumsum_max
             sampled_doc_index = np.searchsorted(self.doc_cumsum, sentence_index, side='right')
         else:
             # If we don't use sentence weighting, then every doc has an equal chance to be chosen
-            sampled_doc_index = current_idx + randint(1, len(self.doc_lengths)-1)
+            sampled_doc_index = (current_idx + randrange(1, len(self.doc_lengths))) % len(self.doc_lengths)
         assert sampled_doc_index != current_idx
         if self.reduce_memory:
             return self.document_shelf[str(sampled_doc_index)]
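As a quick illustration of the rewritten unweighted branch, the following standalone sketch (hypothetical doc_lengths, not part of the patch) shows that randrange(1, n) draws the same offsets as randint(1, n - 1), and that the modulo wraps the offset back into range, so the result is always a valid index that differs from current_idx.

from random import randrange

doc_lengths = [5, 12, 3, 8]   # hypothetical sentence counts for four documents
current_idx = 3               # the document we must not sample again

# randrange(1, 4) yields 1, 2 or 3; adding the offset to current_idx and
# wrapping with the modulo keeps the index in bounds and guarantees it
# differs from current_idx.
sampled_doc_index = (current_idx + randrange(1, len(doc_lengths))) % len(doc_lengths)

assert 0 <= sampled_doc_index < len(doc_lengths)
assert sampled_doc_index != current_idx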
@@ -170,7 +172,7 @@ def create_instances_from_document(
                 # (first) sentence.
                 a_end = 1
                 if len(current_chunk) >= 2:
-                    a_end = randint(1, len(current_chunk) - 1)
+                    a_end = randrange(1, len(current_chunk))
 
                 tokens_a = []
                 for j in range(a_end):
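For readers unfamiliar with this step, a small simplified sketch (hypothetical current_chunk, and ignoring the random-next branch) of how a_end splits a chunk into segment A and segment B: with two or more sentences, randrange(1, n) always leaves at least one sentence for segment B.

from random import randrange

# Hypothetical chunk of three tokenized sentences
current_chunk = [["hello", "world"], ["how", "are", "you"], ["fine", "thanks"]]

a_end = 1
if len(current_chunk) >= 2:
    a_end = randrange(1, len(current_chunk))   # same values as randint(1, n - 1)

tokens_a = [token for sentence in current_chunk[:a_end] for token in sentence]
tokens_b = [token for sentence in current_chunk[a_end:] for token in sentence]
assert tokens_a and tokens_b   # both segments get at least one sentence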
@@ -186,7 +188,7 @@ def create_instances_from_document(
                     # Sample a random document, with longer docs being sampled more frequently
                     random_document = doc_database.sample_doc(current_idx=doc_idx, sentence_weighted=True)
 
-                    random_start = randint(0, len(random_document) - 1)
+                    random_start = randrange(0, len(random_document))
                     for j in range(random_start, len(random_document)):
                         tokens_b.extend(random_document[j])
                         if len(tokens_b) >= target_b_length:
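The comment above refers to the length-weighted sampling implemented in sample_doc(). A condensed sketch of that technique (hypothetical doc_lengths, and without the trick that excludes the current document) using a cumulative sum and np.searchsorted:

import numpy as np
from random import randrange

doc_lengths = [2, 10, 4]                     # hypothetical sentence counts
doc_cumsum = np.cumsum(doc_lengths)          # array([ 2, 12, 16])

# Pick one of the 16 sentences uniformly, then map it back to the document
# that owns it; a document with more sentences is therefore chosen more often.
sentence_index = randrange(0, int(doc_cumsum[-1]))
sampled_doc_index = int(np.searchsorted(doc_cumsum, sentence_index, side='right'))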
@@ -264,6 +266,14 @@ def main():
                 else:
                     tokens = tokenizer.tokenize(line)
                     doc.append(tokens)
+            if doc:
+                docs.add_document(doc)  # If the last doc didn't end on a newline, make sure it still gets added
+            if len(docs) <= 1:
+                exit("ERROR: No document breaks were found in the input file! These are necessary to allow the script to "
+                     "ensure that random NextSentences are not sampled from the same document. Please add blank lines to "
+                     "indicate breaks between documents in your input file. If your dataset does not contain multiple "
+                     "documents, blank lines can be inserted at any natural boundary, such as the ends of chapters, "
+                     "sections or paragraphs.")

         args.output_dir.mkdir(exist_ok=True)
         for epoch in trange(args.epochs_to_generate, desc="Epoch"):
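To make the new error message concrete, here is a small standalone sketch (the corpus string is hypothetical, not part of the commit) of the input format the script expects: blank lines mark document boundaries, which is what gives sample_doc() a different document to draw NextSentence negatives from.

# Hypothetical two-document corpus; the blank line is the document break.
corpus = """The first sentence of document one.
A second sentence in the same document.

Document two begins after the blank line.
Another sentence for document two.
"""

docs, doc = [], []
for line in corpus.splitlines():
    line = line.strip()
    if line == "":
        if doc:
            docs.append(doc)
        doc = []
    else:
        doc.append(line.split())      # stand-in for tokenizer.tokenize(line)
if doc:
    docs.append(doc)                  # keep a last document with no trailing blank line

assert len(docs) > 1, "Need at least two documents for NextSentence sampling"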