Fix tr_loss rescaling factor using global_step
parent ed8fad7390
commit 87b9ec3843
@@ -845,7 +845,7 @@ def main():
                 else:
                     loss.backward()

-                tr_loss += loss.item() * args.gradient_accumulation_steps
+                tr_loss += loss.item()
                 nb_tr_examples += input_ids.size(0)
                 nb_tr_steps += 1
                 if (step + 1) % args.gradient_accumulation_steps == 0:
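The point of this hunk: the loss has already been divided by args.gradient_accumulation_steps before backward(), so accumulating the raw loss.item() and later dividing by global_step (which advances once per optimizer step, not once per micro-batch) gives the mean unscaled loss per optimizer step. Below is a minimal, self-contained sketch of that bookkeeping; train_one_epoch, model, and train_dataloader are illustrative names, not code from this commit.

    def train_one_epoch(model, optimizer, train_dataloader, gradient_accumulation_steps):
        # Sketch of the loss accounting the new code settles on (hypothetical helper).
        tr_loss, global_step = 0.0, 0
        optimizer.zero_grad()
        for step, batch in enumerate(train_dataloader):
            loss = model(*batch)                        # mean loss over the micro-batch
            loss = loss / gradient_accumulation_steps   # rescale so accumulated grads match a full batch
            loss.backward()
            tr_loss += loss.item()                      # accumulate the already-rescaled loss
            if (step + 1) % gradient_accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
                global_step += 1
        # Each optimizer step contributed gradient_accumulation_steps rescaled terms,
        # so tr_loss / global_step is the mean (unscaled) loss per optimizer step.
        return tr_loss / max(global_step, 1)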
@@ -936,7 +936,7 @@ def main():
         elif output_mode == "regression":
             preds = np.squeeze(preds)
         result = compute_metrics(task_name, preds, all_label_ids.numpy())
-        loss = tr_loss/nb_tr_steps if args.do_train else None
+        loss = tr_loss/global_step if args.do_train else None

         result['eval_loss'] = eval_loss
         result['global_step'] = global_step
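Why global_step is the right denominator once the `* args.gradient_accumulation_steps` factor is gone: summing loss/k over n micro-steps and dividing by the n/k optimizer steps recovers the plain mean of the per-micro-batch losses. A tiny arithmetic check with made-up numbers:

    losses = [0.9, 0.7, 0.6, 0.4]            # hypothetical per-micro-batch losses
    k = 2                                     # gradient_accumulation_steps
    tr_loss = sum(l / k for l in losses)      # what the training loop now accumulates
    global_step = len(losses) // k            # optimizer steps taken (assumes full windows)
    assert abs(tr_loss / global_step - sum(losses) / len(losses)) < 1e-12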
@@ -1004,7 +1004,7 @@ def main():
             preds = preds[0]
             preds = np.argmax(preds, axis=1)
         result = compute_metrics(task_name, preds, all_label_ids.numpy())
-        loss = tr_loss/nb_tr_steps if args.do_train else None
+        loss = tr_loss/global_step if args.do_train else None

         result['eval_loss'] = eval_loss
         result['global_step'] = global_step
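The context lines around this hunk show the two prediction post-processing paths fed into compute_metrics. Roughly, and only as a sketch of the assumed shapes (logits of shape (n_examples, n_labels) for classification, a single column for regression):

    import numpy as np

    def postprocess_preds(preds, output_mode):
        # Sketch of the post-processing around these hunks (assumed shapes, hypothetical helper).
        if output_mode == "classification":
            return np.argmax(preds, axis=1)   # logits -> predicted label ids
        if output_mode == "regression":
            return np.squeeze(preds)          # (n_examples, 1) -> (n_examples,)
        raise KeyError(output_mode)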
@@ -452,7 +452,7 @@ def main():
                     loss = loss * args.loss_scale
                 if args.gradient_accumulation_steps > 1:
                     loss = loss / args.gradient_accumulation_steps
-                tr_loss += loss.item() * args.gradient_accumulation_steps
+                tr_loss += loss.item()
                 nb_tr_examples += input_ids.size(0)
                 nb_tr_steps += 1

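The surrounding context in this second file multiplies the loss by args.loss_scale for fp16 training before dividing by the accumulation factor. As a reminder of what static loss scaling does (a simplified sketch with placeholder names, not the optimizer path this example actually wires up): the loss is scaled up so fp16 gradients don't underflow, and the gradients are divided back down before the update.

    def backward_with_static_loss_scale(loss, model, loss_scale=128.0):
        # Simplified static loss scaling sketch (placeholder helper, not from this commit).
        (loss * loss_scale).backward()        # scale up so small fp16 gradients survive
        for p in model.parameters():
            if p.grad is not None:
                p.grad.data.div_(loss_scale)  # unscale before the optimizer step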
@@ -537,7 +537,7 @@ def main():
         result = {'eval_loss': eval_loss,
                   'eval_accuracy': eval_accuracy,
                   'global_step': global_step,
-                  'loss': tr_loss/nb_tr_steps}
+                  'loss': tr_loss/global_step}

         output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
         with open(output_eval_file, "w") as writer:
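This last hunk ends right where the result dict is written out to eval_results.txt. For completeness, a typical write-out loop consistent with the keys shown above (a sketch, not copied from the file):

    import os

    def write_eval_results(result, output_dir):
        # Sketch: dump a result dict like {'eval_loss', 'eval_accuracy', 'global_step', 'loss'} to disk.
        output_eval_file = os.path.join(output_dir, "eval_results.txt")
        with open(output_eval_file, "w") as writer:
            for key in sorted(result.keys()):
                writer.write("%s = %s\n" % (key, str(result[key])))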
|