[Flax] token-classification model steps enumerate start from 1 (#14547)

* step start from 1

* Updated cur_step calcualtion
This commit is contained in:
Kamal Raj 2021-11-29 21:55:59 +05:30 committed by GitHub
parent cea17acd8c
commit 2bd950ca47
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -598,7 +598,7 @@ def main():
state, train_metric, dropout_rngs = p_train_step(state, batch, dropout_rngs)
train_metrics.append(train_metric)
cur_step = epoch * step_per_epoch + step
cur_step = (epoch * step_per_epoch) + (step + 1)
if cur_step % training_args.logging_steps == 0 and cur_step > 0:
# Save metrics