Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
Fix step shifting when accumulating gradients (#33673)
* replace total_batched_samples with step while counting grad accum step
* remove unused variable
* simplify condition for update step
* fix format by ruff
* simplify update step condition using accelerator.sync_gradients
* simplify update condition using do_sync_step
* remove print for test

---------

Co-authored-by: Zach Mueller <muellerzr@gmail.com>
This commit is contained in:
parent 1b86772de5
commit dca93ca076
@@ -2404,7 +2404,6 @@ class Trainer:
         if args.eval_on_start:
             self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True)
 
-        total_batched_samples = 0
         for epoch in range(epochs_trained, num_train_epochs):
             epoch_dataloader = train_dataloader
             if hasattr(epoch_dataloader, "set_epoch"):
@@ -2447,13 +2446,7 @@ class Trainer:
                 batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches)
                 for inputs in batch_samples:
                     step += 1
-                    total_batched_samples += 1
-                    is_last_step_and_steps_less_than_grad_acc = (
-                        steps_in_epoch <= args.gradient_accumulation_steps and (step + 1) == steps_in_epoch
-                    )
-                    do_sync_step = is_last_step_and_steps_less_than_grad_acc or (
-                        total_batched_samples % args.gradient_accumulation_steps == 0
-                    )
+                    do_sync_step = (step + 1) % args.gradient_accumulation_steps == 0 or (step + 1) == steps_in_epoch
                     # Since we perform prefetching, we need to manually set sync_gradients
                     if not do_sync_step:
                         self.accelerator.gradient_state._set_sync_gradients(False)
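A minimal standalone sketch of the behavior this commit changes (not the Trainer code itself; the constants below — two epochs, five micro-batches per epoch, gradient_accumulation_steps = 2 — are made-up assumptions for illustration). It compares the old sync condition, driven by the global total_batched_samples counter, with the new per-epoch step condition:

# Sketch comparing the old (global counter) and new (per-epoch step)
# gradient-accumulation sync conditions. Constants are illustrative only.
gradient_accumulation_steps = 2
steps_in_epoch = 5  # deliberately not a multiple of gradient_accumulation_steps
num_epochs = 2

total_batched_samples = 0  # old: global counter, never reset between epochs
for epoch in range(num_epochs):
    for step in range(steps_in_epoch):  # per-epoch counter, as in the training loop
        total_batched_samples += 1

        # Old condition: the modulo test runs on the global counter, so once an
        # epoch length is not a multiple of gradient_accumulation_steps, the
        # sync points shift in every following epoch.
        is_last_step_and_steps_less_than_grad_acc = (
            steps_in_epoch <= gradient_accumulation_steps and (step + 1) == steps_in_epoch
        )
        old_sync = is_last_step_and_steps_less_than_grad_acc or (
            total_batched_samples % gradient_accumulation_steps == 0
        )

        # New condition: counts within the epoch and always syncs on the last
        # step of the epoch, so every epoch follows the same schedule.
        new_sync = (step + 1) % gradient_accumulation_steps == 0 or (step + 1) == steps_in_epoch

        print(f"epoch {epoch} step {step}: old_sync={old_sync} new_sync={new_sync}")

With these numbers, the old rule syncs at steps 1 and 3 of epoch 0 (never on its last step) but at steps 0, 2 and 4 of epoch 1, while the new rule syncs at steps 1, 3 and 4 of every epoch.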