From 391a4ec2f3b3b77fbe3e42e6df6a82d6e472cf88 Mon Sep 17 00:00:00 2001 From: VictorSanh Date: Sat, 3 Nov 2018 08:25:15 -0400 Subject: [PATCH] Small typo in `trange` I seriously don't understand why they defined num_train_epochs as a float in the originial tf code. I Will change it at the end to avoir merge conflicts for now. --- run_classifier_pytorch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/run_classifier_pytorch.py b/run_classifier_pytorch.py index 426bff64c9c..5d283d3415a 100644 --- a/run_classifier_pytorch.py +++ b/run_classifier_pytorch.py @@ -514,7 +514,7 @@ def main(): model.train() nb_tr_examples = 0 - for epoch in trange(args.num_train_epochs, desc="Epoch"): + for epoch in trange(int(args.num_train_epochs), desc="Epoch"): for input_ids, input_mask, segment_ids, label_ids in tqdm(train_dataloader, desc="Iteration"): input_ids = input_ids.to(device) input_mask = input_mask.float().to(device)