Mirror of https://github.com/huggingface/transformers.git
fix iterator overflow when gradient accumulation is 1 (#35960)
parent 4d3b1076a1
commit 7547f55e5d
@@ -2434,6 +2434,8 @@ class Trainer:
                 remainder = args.gradient_accumulation_steps
             update_step = -1
             total_updates = steps_in_epoch // args.gradient_accumulation_steps + 1
+            if args.gradient_accumulation_steps == 1:
+                total_updates -= 1
             for _ in range(total_updates):
                 update_step += 1
                 num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder
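To make the off-by-one concrete, here is a minimal, self-contained Python sketch (not the Trainer implementation itself) that counts how many batches the update loop would try to pull from one epoch's iterator. The helper name planned_batches, the apply_fix flag, and deriving remainder from steps_in_epoch are illustrative assumptions; only the total_updates arithmetic and the added guard mirror the diff above.

# A minimal sketch, assuming `remainder` is the batch count of the final
# (possibly partial) accumulation window. `planned_batches` and `apply_fix`
# are hypothetical names introduced for illustration.
def planned_batches(steps_in_epoch: int, gradient_accumulation_steps: int, apply_fix: bool = True) -> int:
    """Count how many batches the update loop would try to pull from one epoch's iterator."""
    remainder = steps_in_epoch % gradient_accumulation_steps
    if remainder == 0:
        remainder = gradient_accumulation_steps
    total_updates = steps_in_epoch // gradient_accumulation_steps + 1
    if apply_fix and gradient_accumulation_steps == 1:
        total_updates -= 1  # the guard added by this commit
    consumed = 0
    for update_step in range(total_updates):
        # Full windows consume `gradient_accumulation_steps` batches; the last window takes the remainder.
        num_batches = gradient_accumulation_steps if update_step != (total_updates - 1) else remainder
        consumed += num_batches
    return consumed

assert planned_batches(10, 1, apply_fix=False) == 11  # one batch past a 10-batch epoch: iterator overflow
assert planned_batches(10, 1, apply_fix=True) == 10   # with the guard, exactly the epoch length
assert planned_batches(10, 4) == 10                   # accumulation > 1 was already counted correctly

The asserts show the failure mode the commit title describes: with gradient_accumulation_steps == 1, the unguarded total_updates is one too large, so the last update tries to draw a batch past the end of the epoch iterator; subtracting one update restores the correct count without affecting accumulation values greater than one.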