Bring back set_epoch for Accelerate-based dataloaders (#26850)

* Working tests!

* Fix sampler

* Fix

* Update src/transformers/trainer.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Fix check

* Clean

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
commit 90412401e6 (parent 3c2692407d)
Zach Mueller, 2023-10-26 05:20:11 -04:00, committed by GitHub


@@ -200,6 +200,11 @@ if is_accelerate_available():
         save_fsdp_model,
         save_fsdp_optimizer,
     )
+    DATA_SAMPLERS = [RandomSampler]
+    if version.parse(accelerate_version) > version.parse("0.23.0"):
+        from accelerate.data_loader import SeedableRandomSampler
+
+        DATA_SAMPLERS += [SeedableRandomSampler]
 
     if is_deepspeed_available():
         from accelerate.utils import DeepSpeedSchedulerWrapper
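
A minimal, self-contained sketch of what the version-guarded `DATA_SAMPLERS` list added above enables: one `isinstance` check that recognizes both torch's `RandomSampler` and, on Accelerate releases newer than 0.23.0, the `SeedableRandomSampler` that Accelerate can substitute for it. Names outside the diff (`data_samplers`, `dataset`, `loader`) are illustrative only, not Trainer internals.

```python
from packaging import version
import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset
import accelerate

data_samplers = [RandomSampler]
if version.parse(accelerate.__version__) > version.parse("0.23.0"):
    # Only newer Accelerate releases ship this sampler, hence the guard.
    from accelerate.data_loader import SeedableRandomSampler

    data_samplers.append(SeedableRandomSampler)

dataset = TensorDataset(torch.arange(8))
loader = DataLoader(dataset, sampler=RandomSampler(dataset), batch_size=2)

# isinstance() accepts a tuple of types, so a single check covers both kinds.
print(isinstance(loader.sampler, tuple(data_samplers)))  # True
```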
@@ -1738,7 +1743,10 @@ class Trainer:
             if not args.ignore_data_skip:
                 for epoch in range(epochs_trained):
                     sampler = get_dataloader_sampler(train_dataloader)
-                    is_random_sampler = isinstance(sampler, RandomSampler)
+                    sampler_kinds = [RandomSampler]
+                    if version.parse(accelerate_version) > version.parse("0.23.0"):
+                        sampler_kinds.append(SeedableRandomSampler)
+                    is_random_sampler = isinstance(sampler, tuple(sampler_kinds))
                     if is_torch_less_than_1_11 or not is_random_sampler:
                         # We just need to begin an iteration to create the randomization of the sampler.
                         for _ in train_dataloader:
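
The branch above keeps the existing resume behaviour: when the sampler is not seedable (or torch is older than 1.11), the only way to reproduce the shuffling of already-completed epochs is to replay the start of each one. A hedged, standalone sketch of that idea, assuming the resumed process restores the same seed as the original run; the names here are illustrative, not Trainer internals.

```python
import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

torch.manual_seed(42)  # assume the resumed run restores the original seed

dataset = TensorDataset(torch.arange(10))
loader = DataLoader(dataset, sampler=RandomSampler(dataset), batch_size=5)

epochs_trained = 2  # e.g. read back from a checkpoint (illustrative value)
for _ in range(epochs_trained):
    # Beginning an iteration is enough: a plain RandomSampler draws this
    # epoch's shuffle seed from the global torch RNG, and advancing that RNG
    # is the side effect we need to reproduce before resuming.
    for _ in loader:
        break

# Training can now resume at epoch `epochs_trained` with the same sample
# order the uninterrupted run would have produced.
```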
@@ -1752,6 +1760,8 @@ class Trainer:
         total_batched_samples = 0
         for epoch in range(epochs_trained, num_train_epochs):
             epoch_iterator = train_dataloader
+            if hasattr(epoch_iterator, "set_epoch"):
+                epoch_iterator.set_epoch(epoch)
             # Reset the past mems state at the beginning of each epoch if necessary.
             if args.past_index >= 0:
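
The guarded `set_epoch` call is the core of this change: dataloaders prepared by Accelerate can expose `set_epoch` (hence the `hasattr` check), letting a seedable sampler fold the epoch number into its seed so each epoch shuffles differently yet reproducibly. A minimal sketch of that contract using a hand-written sampler; `EpochSeededSampler` is illustrative, not an Accelerate or Transformers class.

```python
import torch
from torch.utils.data import DataLoader, Sampler, TensorDataset


class EpochSeededSampler(Sampler[int]):
    def __init__(self, data_len: int, seed: int = 0):
        self.data_len = data_len
        self.seed = seed
        self.epoch = 0

    def set_epoch(self, epoch: int) -> None:
        # Called once at the top of every epoch, as the training loop does.
        self.epoch = epoch

    def __iter__(self):
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)  # epoch-dependent but reproducible
        yield from torch.randperm(self.data_len, generator=g).tolist()

    def __len__(self) -> int:
        return self.data_len


dataset = TensorDataset(torch.arange(6))
sampler = EpochSeededSampler(len(dataset))
loader = DataLoader(dataset, sampler=sampler, batch_size=3)

for epoch in range(2):
    # Mirrors the guarded call in the diff: only call set_epoch when the
    # dataloader (or here, its sampler) actually provides it.
    if hasattr(sampler, "set_epoch"):
        sampler.set_epoch(epoch)
    print(epoch, [batch[0].tolist() for batch in loader])
```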