only on main process

This commit is contained in:
parent 326944d627
commit 335f57baf8
@@ -917,7 +917,8 @@ def main():
         model = torch.nn.DataParallel(model)

     if args.do_train:
-        writer = SummaryWriter()
+        if args.local_rank in [-1, 0]:
+            writer = SummaryWriter()
         # Prepare data loader
         train_examples = read_squad_examples(
             input_file=args.train_file, is_training=True, version_2_with_negative=args.version_2_with_negative)
@@ -1016,8 +1017,9 @@ def main():
                 else:
                     loss.backward()
                 if (step + 1) % args.gradient_accumulation_steps == 0:
-                    writer.add_scalar('lr', optimizer.get_lr()[0], global_step)
-                    writer.add_scalar('loss', loss.item(), global_step)
+                    if args.local_rank in [-1, 0]:
+                        writer.add_scalar('lr', optimizer.get_lr()[0], global_step)
+                        writer.add_scalar('loss', loss.item(), global_step)
                     if args.fp16:
                         # modify learning rate with special warm up BERT uses
                         # if args.fp16 is False, BertAdam is used and handles this automatically
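In short, both hunks gate TensorBoard work behind `args.local_rank in [-1, 0]`: the first creates the `SummaryWriter` only on the main process, the second restricts the `add_scalar` calls the same way, so that under `torch.distributed` training only one worker writes event files. Below is a minimal, self-contained sketch of that pattern; the argument parsing and the dummy training loop are illustrative stand-ins, not code from this repository.

import argparse

import torch
from torch.utils.tensorboard import SummaryWriter  # the script here used tensorboardX; either works


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="-1 = not distributed; 0 = main worker")
    args = parser.parse_args()

    # Mirror the first hunk: create the writer only on the main process,
    # otherwise every worker opens its own event file and logs duplicates.
    writer = SummaryWriter() if args.local_rank in [-1, 0] else None

    for global_step in range(3):   # stand-in for the real training loop
        loss = torch.rand(1)       # stand-in for the real loss
        if writer is not None:     # same effect as the second hunk's guard
            writer.add_scalar("loss", loss.item(), global_step)

    if writer is not None:
        writer.close()


if __name__ == "__main__":
    main()

Guarding the logging like this also means the `lr` and `loss` curves in TensorBoard come from a single consistent source instead of being interleaved by every worker in the job.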