diff --git a/examples/flax/language-modeling/run_mlm_flax.py b/examples/flax/language-modeling/run_mlm_flax.py
index 8b395be539b..e3406d4d9bc 100755
--- a/examples/flax/language-modeling/run_mlm_flax.py
+++ b/examples/flax/language-modeling/run_mlm_flax.py
@@ -590,7 +590,7 @@ if __name__ == "__main__":
     # Create learning rate scheduler
     # warmup_steps = 0 causes the Flax optimizer to return NaNs; warmup_steps = 1 is functionally equivalent.
     lr_scheduler_fn = create_learning_rate_scheduler(
-        base_learning_rate=training_args.learning_rate, warmup_steps=min(training_args.warmup_steps, 1)
+        base_learning_rate=training_args.learning_rate, warmup_steps=max(training_args.warmup_steps, 1)
     )
 
     # Create parallel version of the training and evaluation steps
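
Note on the failure mode: the diff clamps warmup_steps to at least 1 because min() would have clamped it to at most 1, defeating any configured warmup, while 0 produces NaNs. A minimal sketch of why 0 breaks, assuming the scheduler's ramp-up divides the current step by warmup_steps (warmup_factor here is a hypothetical helper for illustration, not part of the example script):

import jax.numpy as jnp

def warmup_factor(step, warmup_steps):
    # Hypothetical linear ramp-up: step / warmup_steps, capped at 1.0.
    # With warmup_steps == 0 the ratio is 0/0, which JAX evaluates to NaN
    # rather than raising, and the NaN then poisons every weight update.
    return jnp.minimum(step / warmup_steps, 1.0)

step = jnp.asarray(0.0)
print(warmup_factor(step, 0))          # nan  -> the bug the diff fixes
print(warmup_factor(step, max(0, 1)))  # 0.0  -> finite, as with the clamped value

Since the factor reaches 1.0 after a single step, warmup_steps = 1 behaves the same as no warmup in practice, which is what the in-code comment means by "functionally equivalent".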