commit e6e5f19257 (parent a432b3d466)
Author: thomwolf
Date:   2019-06-18 14:45:14 +02:00


@@ -201,7 +201,7 @@ def main():
     if args.do_train:
         if args.local_rank in [-1, 0]:
-            writer = SummaryWriter()
+            tb_writer = SummaryWriter()
         # Prepare data loader
         train_examples = read_squad_examples(
             input_file=args.train_file, is_training=True, version_2_with_negative=args.version_2_with_negative)
@@ -302,8 +302,8 @@ def main():
                 loss.backward()
                 if (step + 1) % args.gradient_accumulation_steps == 0:
                     if args.local_rank in [-1, 0]:
-                        writer.add_scalar('lr', optimizer.get_lr()[0], global_step)
-                        writer.add_scalar('loss', loss.item(), global_step)
+                        tb_writer.add_scalar('lr', optimizer.get_lr()[0], global_step)
+                        tb_writer.add_scalar('loss', loss.item(), global_step)
                     if args.fp16:
                         # modify learning rate with special warm up BERT uses
                         # if args.fp16 is False, BertAdam is used and handles this automatically
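
For reference, a minimal sketch of the logging pattern this diff settles on: one SummaryWriter created only on the main process (local_rank in [-1, 0]), with scalars logged once per effective optimizer step under gradient accumulation. The diff does not show which SummaryWriter is imported, so the sketch assumes torch.utils.tensorboard, whose add_scalar signature matches the calls above; the toy model, optimizer, and argument parsing are illustrative stand-ins for the script's real ones, and param_groups[0]['lr'] stands in for the BertAdam-specific optimizer.get_lr()[0].

import argparse

import torch
from torch.utils.tensorboard import SummaryWriter  # assumed import; not shown in the diff

parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=-1)
parser.add_argument("--gradient_accumulation_steps", type=int, default=2)
args = parser.parse_args()

model = torch.nn.Linear(10, 1)  # toy stand-in for the real model
optimizer = torch.optim.SGD(model.parameters(), lr=3e-5)

# Create the writer only on the main process, so that in distributed
# training a single process owns the TensorBoard event files.
tb_writer = SummaryWriter() if args.local_rank in [-1, 0] else None

global_step = 0
for step in range(100):
    loss = model(torch.randn(8, 10)).pow(2).mean()
    loss = loss / args.gradient_accumulation_steps  # scale for accumulation
    loss.backward()
    if (step + 1) % args.gradient_accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
        global_step += 1
        if tb_writer is not None:
            # Log once per effective optimizer step, keyed by global_step,
            # mirroring the tb_writer.add_scalar calls in the diff.
            tb_writer.add_scalar("lr", optimizer.param_groups[0]["lr"], global_step)
            tb_writer.add_scalar("loss", loss.item(), global_step)

if tb_writer is not None:
    tb_writer.close()

Guarding every tb_writer call on the main process keeps worker ranks from racing on the same event files, which is why both hunks test args.local_rank in [-1, 0] before touching the writer.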