mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Merge pull request #478 from Rocketknight1/master
Added a helpful error for users with single-document corpuses - fixes #452
commit 7873d76464
@@ -4,7 +4,7 @@ from tqdm import tqdm, trange
 from tempfile import TemporaryDirectory
 import shelve
 
-from random import random, randint, shuffle, choice, sample
+from random import random, randrange, randint, shuffle, choice, sample
 from pytorch_pretrained_bert.tokenization import BertTokenizer
 import numpy as np
 import json
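Why the import changes: Python's `random.randint(a, b)` includes both endpoints, while `random.randrange(a, b)` excludes `b`, so every `randint(a, b - 1)` in this file can become `randrange(a, b)` with no change in behavior; the half-open form matches `range()` and drops the `- 1` off-by-one noise. A minimal sketch of the equivalence (illustrative only, not part of the patch):

    from random import Random

    # randint(a, b) draws from [a, b]; randrange(a, b) draws from [a, b).
    # Hence randint(a, b - 1) and randrange(a, b) draw from the same set.
    # In CPython, randint(a, b) delegates to randrange(a, b + 1), so the
    # seeded streams below are not just same-distribution but identical.
    a, b = 1, 10
    rng_old, rng_new = Random(0), Random(0)
    assert [rng_old.randint(a, b - 1) for _ in range(1000)] == \
           [rng_new.randrange(a, b) for _ in range(1000)]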
@@ -30,6 +30,8 @@ class DocumentDatabase:
         self.reduce_memory = reduce_memory
 
     def add_document(self, document):
+        if not document:
+            return
         if self.reduce_memory:
             current_idx = len(self.doc_lengths)
             self.document_shelf[str(current_idx)] = document
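The new guard matters because this script treats a blank line as a document separator: consecutive blank lines, or a blank line at the top of the file, would otherwise register empty documents and skew sampling. A standalone sketch of the convention and the guard's effect (the `docs` list below is a hypothetical stand-in for `DocumentDatabase`):

    # Hypothetical reader loop mirroring the script's blank-line convention.
    lines = ["Sentence one.", "", "", "Sentence two.", ""]
    docs, doc = [], []
    for line in lines:
        if line.strip() == "":   # blank line marks a document boundary
            if doc:              # the guard: never add an empty document
                docs.append(doc)
            doc = []
        else:
            doc.append(line)
    if doc:
        docs.append(doc)
    assert docs == [["Sentence one."], ["Sentence two."]]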
@@ -49,11 +51,11 @@ class DocumentDatabase:
                 self._precalculate_doc_weights()
             rand_start = self.doc_cumsum[current_idx]
             rand_end = rand_start + self.cumsum_max - self.doc_lengths[current_idx]
-            sentence_index = randint(rand_start, rand_end-1) % self.cumsum_max
+            sentence_index = randrange(rand_start, rand_end) % self.cumsum_max
             sampled_doc_index = np.searchsorted(self.doc_cumsum, sentence_index, side='right')
         else:
             # If we don't use sentence weighting, then every doc has an equal chance to be chosen
-            sampled_doc_index = current_idx + randint(1, len(self.doc_lengths)-1)
+            sampled_doc_index = (current_idx + randrange(1, len(self.doc_lengths))) % len(self.doc_lengths)
         assert sampled_doc_index != current_idx
         if self.reduce_memory:
             return self.document_shelf[str(sampled_doc_index)]
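Both replacements in this hunk preserve `sample_doc`'s invariant that the returned document differs from `current_idx`. The weighted branch skips the `doc_lengths[current_idx]` positions belonging to the current document when drawing from the cumulative sentence counts. The unweighted branch now also wraps with `% len(self.doc_lengths)`: adding an offset drawn from `[1, n)` modulo `n` gives a uniform draw over every index except the current one, whereas the old unwrapped sum could run past the end of the list. A minimal sketch of that modular-offset trick:

    from collections import Counter
    from random import randrange

    def sample_other_index(current_idx, n):
        # An offset in [1, n) is never 0 mod n, so the result != current_idx.
        return (current_idx + randrange(1, n)) % n

    counts = Counter(sample_other_index(current_idx=3, n=5) for _ in range(10000))
    assert 3 not in counts               # the current index is never drawn
    assert set(counts) == {0, 1, 2, 4}   # every other index is reachable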
@@ -170,7 +172,7 @@ def create_instances_from_document(
                 # (first) sentence.
                 a_end = 1
                 if len(current_chunk) >= 2:
-                    a_end = randint(1, len(current_chunk) - 1)
+                    a_end = randrange(1, len(current_chunk))
 
                 tokens_a = []
                 for j in range(a_end):
@@ -186,7 +188,7 @@ def create_instances_from_document(
                     # Sample a random document, with longer docs being sampled more frequently
                     random_document = doc_database.sample_doc(current_idx=doc_idx, sentence_weighted=True)
 
-                    random_start = randint(0, len(random_document) - 1)
+                    random_start = randrange(0, len(random_document))
                     for j in range(random_start, len(random_document)):
                         tokens_b.extend(random_document[j])
                         if len(tokens_b) >= target_b_length:
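For context, the two hunks above sit inside the NextSentence pair construction: `a_end` decides how many leading sentences of the current chunk form segment A, and for negative pairs segment B is filled from a randomly sampled other document starting at a random sentence. A rough sketch of that flow under the names visible in the diff (the loop bodies are paraphrased, not copied from the file):

    from random import randrange

    def build_pair(current_chunk, random_document, target_b_length):
        # Segment A: a random-length prefix of the current chunk of sentences.
        a_end = 1
        if len(current_chunk) >= 2:
            a_end = randrange(1, len(current_chunk))
        tokens_a = []
        for j in range(a_end):
            tokens_a.extend(current_chunk[j])

        # Segment B (negative pair): sentences from another document,
        # starting at a uniformly chosen sentence index.
        tokens_b = []
        random_start = randrange(0, len(random_document))
        for j in range(random_start, len(random_document)):
            tokens_b.extend(random_document[j])
            if len(tokens_b) >= target_b_length:
                break
        return tokens_a, tokens_b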
@@ -264,6 +266,14 @@ def main():
                 else:
                     tokens = tokenizer.tokenize(line)
                     doc.append(tokens)
+            if doc:
+                docs.add_document(doc)  # If the last doc didn't end on a newline, make sure it still gets added
+            if len(docs) <= 1:
+                exit("ERROR: No document breaks were found in the input file! These are necessary to allow the script to "
+                     "ensure that random NextSentences are not sampled from the same document. Please add blank lines to "
+                     "indicate breaks between documents in your input file. If your dataset does not contain multiple "
+                     "documents, blank lines can be inserted at any natural boundary, such as the ends of chapters, "
+                     "sections or paragraphs.")
 
         args.output_dir.mkdir(exist_ok=True)
         for epoch in trange(args.epochs_to_generate, desc="Epoch"):
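Without this check, a single-document corpus only fails later, inside `sample_doc`, where there is no other document to draw and the sampling would fail with an opaque error. The message spells out the expected input format: plain text, sentences on separate lines, and a blank line between documents, along the lines of this hypothetical excerpt:

    First sentence of document one.
    Second sentence of document one.

    First sentence of document two.
    Second sentence of document two.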