mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 11:11:05 +06:00
fixing
This commit is contained in:
parent
92e0ad5aba
commit
16a1f338c4
@ -222,6 +222,10 @@ def main():
|
|||||||
elif n_gpu > 1:
|
elif n_gpu > 1:
|
||||||
model = torch.nn.DataParallel(model)
|
model = torch.nn.DataParallel(model)
|
||||||
|
|
||||||
|
global_step = 0
|
||||||
|
nb_tr_steps = 0
|
||||||
|
tr_loss = 0
|
||||||
|
|
||||||
if args.do_train:
|
if args.do_train:
|
||||||
if args.local_rank in [-1, 0]:
|
if args.local_rank in [-1, 0]:
|
||||||
tb_writer = SummaryWriter()
|
tb_writer = SummaryWriter()
|
||||||
@ -293,10 +297,6 @@ def main():
|
|||||||
warmup=args.warmup_proportion,
|
warmup=args.warmup_proportion,
|
||||||
t_total=num_train_optimization_steps)
|
t_total=num_train_optimization_steps)
|
||||||
|
|
||||||
global_step = 0
|
|
||||||
nb_tr_steps = 0
|
|
||||||
tr_loss = 0
|
|
||||||
|
|
||||||
logger.info("***** Running training *****")
|
logger.info("***** Running training *****")
|
||||||
logger.info(" Num examples = %d", len(train_examples))
|
logger.info(" Num examples = %d", len(train_examples))
|
||||||
logger.info(" Batch size = %d", args.train_batch_size)
|
logger.info(" Batch size = %d", args.train_batch_size)
|
||||||
|
Loading…
Reference in New Issue
Block a user