Syncing up argument names between the scripts
commit 934d3f4d2f
parent f19ba35b2b
@@ -201,8 +201,8 @@ def create_instances_from_document(

 def main():
     parser = ArgumentParser()
-    parser.add_argument('--corpus_path', type=Path, required=True)
-    parser.add_argument("--save_dir", type=Path, required=True)
+    parser.add_argument('--train_corpus', type=Path, required=True)
+    parser.add_argument("--output_dir", type=Path, required=True)
     parser.add_argument("--bert_model", type=str, required=True,
                         choices=["bert-base-uncased", "bert-large-uncased", "bert-base-cased",
                                  "bert-base-multilingual", "bert-base-chinese"])
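Note: below is a minimal sketch of how the renamed arguments behave after this change. Only the options visible in the hunk above are reproduced; the example argv values are illustrative assumptions, not values from the repo.

# Sketch: parse the renamed --train_corpus / --output_dir options.
# Only the arguments shown in the hunk above are included here.
from argparse import ArgumentParser
from pathlib import Path

parser = ArgumentParser()
parser.add_argument('--train_corpus', type=Path, required=True)
parser.add_argument("--output_dir", type=Path, required=True)
parser.add_argument("--bert_model", type=str, required=True,
                    choices=["bert-base-uncased", "bert-large-uncased", "bert-base-cased",
                             "bert-base-multilingual", "bert-base-chinese"])

# Example invocation with made-up paths (assumptions):
args = parser.parse_args(["--train_corpus", "corpus.txt",
                          "--output_dir", "training_data",
                          "--bert_model", "bert-base-uncased"])
print(args.train_corpus, args.output_dir, args.bert_model)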
@@ -229,7 +229,7 @@ def main():

     tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
     vocab_list = list(tokenizer.vocab.keys())
-    with args.corpus_path.open() as f:
+    with args.train_corpus.open() as f:
         docs = []
         doc = []
         for line in tqdm(f, desc="Loading Dataset"):
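For context, a hedged sketch of the loading loop this hunk touches, assuming the usual convention that documents in the corpus are separated by blank lines; the helper name and structure are illustrative, not the script's actual code.

# Sketch: split a corpus file into documents on blank lines and tokenize each
# sentence, mirroring the docs/doc/for-line structure in the context lines above.
def load_docs(train_corpus, tokenizer):
    docs, doc = [], []
    with train_corpus.open() as f:           # train_corpus is a pathlib.Path
        for line in f:
            line = line.strip()
            if line == "":                   # blank line: end of current document
                if doc:
                    docs.append(doc)
                doc = []
            else:
                doc.append(tokenizer.tokenize(line))
    if doc:                                  # corpus may not end with a blank line
        docs.append(doc)
    return docs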
@@ -241,7 +241,7 @@ def main():
                 tokens = tokenizer.tokenize(line)
                 doc.append(tokens)

-    args.save_dir.mkdir(exist_ok=True)
+    args.output_dir.mkdir(exist_ok=True)
     docs = DocumentDatabase(docs)
     # When choosing a random sentence, we should sample docs proportionally to the number of sentences they contain
     # Google BERT doesn't do this, and as a result oversamples shorter docs
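The two comments above motivate length-proportional document sampling. A minimal sketch of that idea, assuming docs is a list of documents where each document is a list of tokenized sentences (the actual DocumentDatabase implementation is not part of this diff):

# Sketch: pick a document with probability proportional to its sentence count,
# so a random sentence is uniform over the corpus rather than over documents.
import random
from bisect import bisect_right
from itertools import accumulate

def sample_doc_index(docs, rng=random):
    cumulative = list(accumulate(len(doc) for doc in docs))  # e.g. [3, 8, 9]
    target = rng.randrange(cumulative[-1])    # uniform over all sentences
    return bisect_right(cumulative, target)   # index of the doc owning that sentence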
@@ -256,8 +256,8 @@ def main():
             epoch_instances.extend(doc_instances)

         shuffle(epoch_instances)
-        epoch_file = args.save_dir / f"epoch_{epoch}.json"
-        metrics_file = args.save_dir / f"epoch_{epoch}_metrics.json"
+        epoch_file = args.output_dir / f"epoch_{epoch}.json"
+        metrics_file = args.output_dir / f"epoch_{epoch}_metrics.json"
         with epoch_file.open('w') as out_file:
             for instance in epoch_instances:
                 out_file.write(instance + '\n')
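For illustration, a sketch of reading one pre-generated epoch file back, assuming each written instance is a JSON-encoded training example on its own line, as the write loop and the .json filename above suggest; the helper name is hypothetical.

# Sketch: load the instances written to epoch_{epoch}.json, one JSON object per line.
import json
from pathlib import Path

def load_epoch_instances(output_dir, epoch):
    epoch_file = Path(output_dir) / f"epoch_{epoch}.json"
    with epoch_file.open() as f:
        return [json.loads(line) for line in f if line.strip()]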
@@ -401,7 +401,7 @@ def main():
     parser = argparse.ArgumentParser()

     ## Required parameters
-    parser.add_argument("--train_file",
+    parser.add_argument("--train_corpus",
                         default=None,
                         type=str,
                         required=True,
@@ -511,8 +511,8 @@ def main():
     #train_examples = None
     num_train_optimization_steps = None
     if args.do_train:
-        print("Loading Train Dataset", args.train_file)
-        train_dataset = BERTDataset(args.train_file, tokenizer, seq_len=args.max_seq_length,
+        print("Loading Train Dataset", args.train_corpus)
+        train_dataset = BERTDataset(args.train_corpus, tokenizer, seq_len=args.max_seq_length,
                                     corpus_lines=None, on_memory=args.on_memory)
         num_train_optimization_steps = int(
             len(train_dataset) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
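The last context line above computes the number of optimization steps. A quick standalone check of that arithmetic with made-up values (dataset size, batch size, accumulation factor and epoch count are assumptions, not values from the repo):

# Sketch: the step-count formula from the hunk above, evaluated on example values.
train_dataset_len = 100_000        # assumed number of training examples
train_batch_size = 32              # assumed batch size
gradient_accumulation_steps = 2    # assumed accumulation factor
num_train_epochs = 3               # assumed epoch count

num_train_optimization_steps = int(
    train_dataset_len / train_batch_size / gradient_accumulation_steps) * num_train_epochs
print(num_train_optimization_steps)  # 4686: roughly one optimizer step per 64 examples, per epoch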