Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
clean up distributed training logging in run_squad example
This commit is contained in:
parent 1135f2384a
commit 7816f7921f
@@ -985,7 +985,7 @@ def main():
         model.train()
         for _ in trange(int(args.num_train_epochs), desc="Epoch"):
-            for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
+            for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])):
                 if n_gpu == 1:
                     batch = tuple(t.to(device) for t in batch) # multi-gpu does scattering it-self
                 input_ids, input_mask, segment_ids, start_positions, end_positions = batch
@@ -1058,7 +1058,7 @@ def main():
         model.eval()
         all_results = []
         logger.info("Start evaluating")
-        for input_ids, input_mask, segment_ids, example_indices in tqdm(eval_dataloader, desc="Evaluating"):
+        for input_ids, input_mask, segment_ids, example_indices in tqdm(eval_dataloader, desc="Evaluating", disable=args.local_rank not in [-1, 0]):
             if len(all_results) % 1000 == 0:
                 logger.info("Processing example: %d" % (len(all_results)))
             input_ids = input_ids.to(device)
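
Both hunks apply the same idiom: tqdm's disable flag silences the progress bar on every process except the primary one (local_rank 0, or -1 when the script is not running distributed). A minimal standalone sketch of that idiom, assuming the script is launched via torch.distributed.launch (which supplies --local_rank to each worker); the argument parser and toy loop below are illustrative, not part of run_squad:

    import argparse

    from tqdm import tqdm

    parser = argparse.ArgumentParser()
    # torch.distributed.launch passes --local_rank to each worker process;
    # -1 is the conventional value for a single-process (non-distributed) run.
    parser.add_argument("--local_rank", type=int, default=-1)
    args = parser.parse_args()

    # Draw the progress bar only on the primary process (rank 0) or in a
    # non-distributed run (-1); every other rank gets disable=True, so its
    # bar is suppressed and the console shows a single set of updates.
    for step in tqdm(range(1000), desc="Iteration", disable=args.local_rank not in [-1, 0]):
        pass  # the actual training/evaluation step would go here

Without the flag, every rank draws its own bar and the interleaved carriage-return updates make the log unreadable; gating on local_rank keeps the output to one bar per node.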