diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 927adcb3955..fe15d57bb11 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -2434,6 +2434,8 @@ class Trainer: remainder = args.gradient_accumulation_steps update_step = -1 total_updates = steps_in_epoch // args.gradient_accumulation_steps + 1 + if args.gradient_accumulation_steps == 1: + total_updates -= 1 for _ in range(total_updates): update_step += 1 num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder