mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-21 13:38:31 +06:00
Small typo in trange
I seriously don't understand why they defined num_train_epochs as a float in the originial tf code. I Will change it at the end to avoir merge conflicts for now.
This commit is contained in:
parent
5676d6f799
commit
391a4ec2f3
@ -514,7 +514,7 @@ def main():
|
|||||||
|
|
||||||
model.train()
|
model.train()
|
||||||
nb_tr_examples = 0
|
nb_tr_examples = 0
|
||||||
for epoch in trange(args.num_train_epochs, desc="Epoch"):
|
for epoch in trange(int(args.num_train_epochs), desc="Epoch"):
|
||||||
for input_ids, input_mask, segment_ids, label_ids in tqdm(train_dataloader, desc="Iteration"):
|
for input_ids, input_mask, segment_ids, label_ids in tqdm(train_dataloader, desc="Iteration"):
|
||||||
input_ids = input_ids.to(device)
|
input_ids = input_ids.to(device)
|
||||||
input_mask = input_mask.float().to(device)
|
input_mask = input_mask.float().to(device)
|
||||||
|
Loading…
Reference in New Issue
Block a user