diff --git a/examples/flax/token-classification/run_flax_ner.py b/examples/flax/token-classification/run_flax_ner.py index f90ac2bb389..ceb2b71624d 100644 --- a/examples/flax/token-classification/run_flax_ner.py +++ b/examples/flax/token-classification/run_flax_ner.py @@ -598,7 +598,7 @@ def main(): state, train_metric, dropout_rngs = p_train_step(state, batch, dropout_rngs) train_metrics.append(train_metric) - cur_step = epoch * step_per_epoch + step + cur_step = (epoch * step_per_epoch) + (step + 1) if cur_step % training_args.logging_steps == 0 and cur_step > 0: # Save metrics