mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 18:51:14 +06:00
only on main process
This commit is contained in:
parent
326944d627
commit
335f57baf8
@ -917,7 +917,8 @@ def main():
|
|||||||
model = torch.nn.DataParallel(model)
|
model = torch.nn.DataParallel(model)
|
||||||
|
|
||||||
if args.do_train:
|
if args.do_train:
|
||||||
writer = SummaryWriter()
|
if args.local_rank in [-1, 0]:
|
||||||
|
writer = SummaryWriter()
|
||||||
# Prepare data loader
|
# Prepare data loader
|
||||||
train_examples = read_squad_examples(
|
train_examples = read_squad_examples(
|
||||||
input_file=args.train_file, is_training=True, version_2_with_negative=args.version_2_with_negative)
|
input_file=args.train_file, is_training=True, version_2_with_negative=args.version_2_with_negative)
|
||||||
@ -1016,8 +1017,9 @@ def main():
|
|||||||
else:
|
else:
|
||||||
loss.backward()
|
loss.backward()
|
||||||
if (step + 1) % args.gradient_accumulation_steps == 0:
|
if (step + 1) % args.gradient_accumulation_steps == 0:
|
||||||
writer.add_scalar('lr', optimizer.get_lr()[0], global_step)
|
if args.local_rank in [-1, 0]:
|
||||||
writer.add_scalar('loss', loss.item(), global_step)
|
writer.add_scalar('lr', optimizer.get_lr()[0], global_step)
|
||||||
|
writer.add_scalar('loss', loss.item(), global_step)
|
||||||
if args.fp16:
|
if args.fp16:
|
||||||
# modify learning rate with special warm up BERT uses
|
# modify learning rate with special warm up BERT uses
|
||||||
# if args.fp16 is False, BertAdam is used and handles this automatically
|
# if args.fp16 is False, BertAdam is used and handles this automatically
|
||||||
|
Loading…
Reference in New Issue
Block a user