Syncing up argument names between the scripts

Matthew Carrigan 2019-03-20 17:23:23 +00:00
parent f19ba35b2b
commit 934d3f4d2f
2 changed files with 9 additions and 9 deletions

View File

@@ -201,8 +201,8 @@ def create_instances_from_document(
def main():
parser = ArgumentParser()
- parser.add_argument('--corpus_path', type=Path, required=True)
- parser.add_argument("--save_dir", type=Path, required=True)
+ parser.add_argument('--train_corpus', type=Path, required=True)
+ parser.add_argument("--output_dir", type=Path, required=True)
parser.add_argument("--bert_model", type=str, required=True,
choices=["bert-base-uncased", "bert-large-uncased", "bert-base-cased",
"bert-base-multilingual", "bert-base-chinese"])
@@ -229,7 +229,7 @@ def main():
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
vocab_list = list(tokenizer.vocab.keys())
- with args.corpus_path.open() as f:
+ with args.train_corpus.open() as f:
docs = []
doc = []
for line in tqdm(f, desc="Loading Dataset"):
@@ -241,7 +241,7 @@ def main():
tokens = tokenizer.tokenize(line)
doc.append(tokens)
- args.save_dir.mkdir(exist_ok=True)
+ args.output_dir.mkdir(exist_ok=True)
docs = DocumentDatabase(docs)
# When choosing a random sentence, we should sample docs proportionally to the number of sentences they contain
# Google BERT doesn't do this, and as a result oversamples shorter docs
@@ -256,8 +256,8 @@ def main():
epoch_instances.extend(doc_instances)
shuffle(epoch_instances)
- epoch_file = args.save_dir / f"epoch_{epoch}.json"
- metrics_file = args.save_dir / f"epoch_{epoch}_metrics.json"
+ epoch_file = args.output_dir / f"epoch_{epoch}.json"
+ metrics_file = args.output_dir / f"epoch_{epoch}_metrics.json"
with epoch_file.open('w') as out_file:
for instance in epoch_instances:
out_file.write(instance + '\n')
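For reference, the pre-generation script's command-line interface after this change reads roughly as below. Only the argument names, types, and --bert_model choices are taken from the diff; the standalone structure and comments are an illustrative sketch.

from argparse import ArgumentParser
from pathlib import Path

# Harmonized CLI: --corpus_path is now --train_corpus and --save_dir is now
# --output_dir, matching the fine-tuning script in the second file below.
parser = ArgumentParser()
parser.add_argument('--train_corpus', type=Path, required=True)
parser.add_argument("--output_dir", type=Path, required=True)
parser.add_argument("--bert_model", type=str, required=True,
                    choices=["bert-base-uncased", "bert-large-uncased",
                             "bert-base-cased", "bert-base-multilingual",
                             "bert-base-chinese"])
args = parser.parse_args()

# The corpus is read from args.train_corpus, and each epoch's instances and
# metrics are written to args.output_dir as epoch_{epoch}.json and
# epoch_{epoch}_metrics.json, as in the hunks above.
args.output_dir.mkdir(exist_ok=True)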

View File

@@ -401,7 +401,7 @@ def main():
parser = argparse.ArgumentParser()
## Required parameters
parser.add_argument("--train_file",
parser.add_argument("--train_corpus",
default=None,
type=str,
required=True,
@@ -511,8 +511,8 @@ def main():
#train_examples = None
num_train_optimization_steps = None
if args.do_train:
print("Loading Train Dataset", args.train_file)
train_dataset = BERTDataset(args.train_file, tokenizer, seq_len=args.max_seq_length,
print("Loading Train Dataset", args.train_corpus)
train_dataset = BERTDataset(args.train_corpus, tokenizer, seq_len=args.max_seq_length,
corpus_lines=None, on_memory=args.on_memory)
num_train_optimization_steps = int(
len(train_dataset) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
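On the fine-tuning side, the renamed flag threads into the dataset construction as in the hunks above. A minimal, runnable sketch of the relevant argument parsing follows; only --train_corpus comes from the diff, while the other flags it interacts with (--do_train, --max_seq_length, --on_memory) are reproduced here with assumed defaults for illustration.

from argparse import ArgumentParser

parser = ArgumentParser()
# Renamed from --train_file so both scripts accept the same corpus flag.
parser.add_argument("--train_corpus", default=None, type=str, required=True,
                    help="The input training corpus.")
parser.add_argument("--do_train", action="store_true")
parser.add_argument("--max_seq_length", default=128, type=int)
parser.add_argument("--on_memory", action="store_true")
args = parser.parse_args()

if args.do_train:
    print("Loading Train Dataset", args.train_corpus)
    # In the actual script, BERTDataset(args.train_corpus, tokenizer,
    # seq_len=args.max_seq_length, corpus_lines=None,
    # on_memory=args.on_memory) is constructed at this point, per the hunk above.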