[Flax] Fix cur step flax examples (#12608)

* fix_torch_device_generate_test

* remove @

* fix save problem
This commit is contained in:
Patrick von Platen 2021-07-09 13:51:28 +01:00 committed by GitHub
parent 65e27215ba
commit deecdd4939
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 0 additions and 3 deletions

View File

@ -622,7 +622,6 @@ def main():
# Save metrics
if has_tensorboard and jax.process_index() == 0:
cur_step = epoch * (len(train_dataset) // train_batch_size)
write_eval_metric(summary_writer, eval_metrics, cur_step)
if cur_step % training_args.save_steps == 0 and cur_step > 0:

View File

@ -663,7 +663,6 @@ if __name__ == "__main__":
# Save metrics
if has_tensorboard and jax.process_index() == 0:
cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size)
write_eval_metric(summary_writer, eval_metrics, cur_step)
if cur_step % training_args.save_steps == 0 and cur_step > 0:

View File

@ -771,7 +771,6 @@ if __name__ == "__main__":
# Save metrics
if has_tensorboard and jax.process_index() == 0:
cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size)
write_eval_metric(summary_writer, eval_metrics, cur_step)
if cur_step % training_args.save_steps == 0 and cur_step > 0: