Mirror of https://github.com/huggingface/transformers.git
Refactor logs and fix loss bug
parent 05d4232f63
commit 41aa0e8003
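The commit does two things. First, the refactor: every logged scalar now flows through a single `logs` dict that is written once to TensorBoard and once to stdout as a JSON line, and the `str(value)` cast on eval metrics is dropped so the values stay numeric for `add_scalar`. Second, the fix: the `logging_loss = tr_loss` snapshot is moved below the `loss_scalar` computation. A minimal sketch of the new logging flow, using placeholder values in place of the real training state and assuming the `torch.utils.tensorboard` SummaryWriter:

    import json
    from torch.utils.tensorboard import SummaryWriter

    # Placeholder state standing in for the real training loop.
    tb_writer = SummaryWriter()
    global_step = 100
    logs = {'learning_rate': 3e-05, 'loss': 0.5, 'eval_acc': 0.87}

    # One fan-out per scalar: TensorBoard gets each key/value pair, stdout
    # gets the whole dict as a single JSON line with the step attached.
    for key, value in logs.items():
        tb_writer.add_scalar(key, value, global_step)
    print(json.dumps({**logs, **{'step': global_step}}))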
@@ -171,22 +171,22 @@ def train(args, train_dataset, model, tokenizer):
                 global_step += 1
 
                 if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
-                    # Log metrics
-                    logs = {'step': global_step}
+                    logs = {}
                     if args.local_rank == -1 and args.evaluate_during_training:  # Only evaluate when single GPU otherwise metrics may not average well
                         results = evaluate(args, model, tokenizer)
                         for key, value in results.items():
                             eval_key = 'eval_{}'.format(key)
-                            tb_writer.add_scalar(eval_key, value, global_step)
-                            logs[eval_key] = str(value)
+                            logs[eval_key] = value
 
-                    logging_loss = tr_loss
                     loss_scalar = (tr_loss - logging_loss) / args.logging_steps
                     learning_rate_scalar = scheduler.get_lr()[0]
-                    tb_writer.add_scalar('lr', learning_rate_scalar, global_step)
-                    tb_writer.add_scalar('loss', loss_scalar, global_step)
                     logs['learning_rate'] = learning_rate_scalar
                     logs['loss'] = loss_scalar
-                    print(json.dumps(logs))
+                    logging_loss = tr_loss
+
+                    for key, value in logs.items():
+                        tb_writer.add_scalar(key, value, global_step)
+                    print(json.dumps({**logs, **{'step': global_step}}))
 
                 if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
                     # Save model checkpoint
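The loss bug named in the title is the ordering shown in isolation below, with made-up numbers: `tr_loss` is the cumulative training loss and `logging_loss` the snapshot from the previous logging step, so resetting the snapshot before computing the window average makes the reported loss 0.0 on every logging step.

    # Standalone illustration of the ordering bug; the numbers are hypothetical.
    logging_steps = 50
    tr_loss, logging_loss = 125.0, 100.0

    # Old ordering: snapshot first, subtract second -> the delta is always 0.0.
    logging_loss = tr_loss
    broken_scalar = (tr_loss - logging_loss) / logging_steps
    assert broken_scalar == 0.0

    # Fixed ordering: average the last window first, then reset the snapshot.
    logging_loss = 100.0
    loss_scalar = (tr_loss - logging_loss) / logging_steps
    logging_loss = tr_loss
    assert loss_scalar == 0.5  # (125.0 - 100.0) / 50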