clean up distributed training logging in run_squad example

thomwolf 2019-04-15 15:27:10 +02:00
parent 1135f2384a
commit 7816f7921f


@@ -985,7 +985,7 @@ def main():
         model.train()
         for _ in trange(int(args.num_train_epochs), desc="Epoch"):
-            for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
+            for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])):
                 if n_gpu == 1:
                     batch = tuple(t.to(device) for t in batch) # multi-gpu does scattering it-self
                 input_ids, input_mask, segment_ids, start_positions, end_positions = batch
@@ -1058,7 +1058,7 @@ def main():
         model.eval()
         all_results = []
         logger.info("Start evaluating")
-        for input_ids, input_mask, segment_ids, example_indices in tqdm(eval_dataloader, desc="Evaluating"):
+        for input_ids, input_mask, segment_ids, example_indices in tqdm(eval_dataloader, desc="Evaluating", disable=args.local_rank not in [-1, 0]):
             if len(all_results) % 1000 == 0:
                 logger.info("Processing example: %d" % (len(all_results)))
             input_ids = input_ids.to(device)