Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 10:12:23 +06:00)
Bring back set_epoch for Accelerate-based dataloaders (#26850)

* Working tests!
* Fix sampler
* Fix
* Update src/transformers/trainer.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Fix check
* Clean

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
parent 3c2692407d
commit 90412401e6
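For context, the hook this commit restores follows the same per-epoch reshuffling convention as PyTorch's DistributedSampler: unless set_epoch() is called before each epoch, a shuffling sampler that seeds itself by epoch replays the exact same order every time. Below is a minimal, self-contained sketch of that general pattern, not Trainer code, using torch's DistributedSampler rather than Accelerate's dataloader wrappers.

# Illustrative only: why set_epoch matters for shuffling samplers.
# Without set_epoch(), DistributedSampler replays the same order every epoch.
import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(8))
sampler = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True, seed=42)
loader = DataLoader(dataset, batch_size=2, sampler=sampler)

for epoch in range(2):
    sampler.set_epoch(epoch)  # shuffle is seeded with (seed + epoch) -> new, reproducible order
    order = [int(x) for batch in loader for x in batch[0]]
    print(f"epoch {epoch}: {order}")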
@@ -200,6 +200,11 @@ if is_accelerate_available():
         save_fsdp_model,
         save_fsdp_optimizer,
     )
+    DATA_SAMPLERS = [RandomSampler]
+    if version.parse(accelerate_version) > version.parse("0.23.0"):
+        from accelerate.data_loader import SeedableRandomSampler
+
+        DATA_SAMPLERS += [SeedableRandomSampler]
 
     if is_deepspeed_available():
         from accelerate.utils import DeepSpeedSchedulerWrapper
@@ -1738,7 +1743,10 @@ class Trainer:
             if not args.ignore_data_skip:
                 for epoch in range(epochs_trained):
                     sampler = get_dataloader_sampler(train_dataloader)
-                    is_random_sampler = isinstance(sampler, RandomSampler)
+                    sampler_kinds = [RandomSampler]
+                    if version.parse(accelerate_version) > version.parse("0.23.0"):
+                        sampler_kinds.append(SeedableRandomSampler)
+                    is_random_sampler = isinstance(sampler, tuple(sampler_kinds))
                     if is_torch_less_than_1_11 or not is_random_sampler:
                         # We just need to begin an iteration to create the randomization of the sampler.
                         for _ in train_dataloader:
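The hunk above can be read as a standalone snippet; the sketch below restates it outside the Trainer, with the same version gate (SeedableRandomSampler only exists in accelerate releases newer than 0.23.0) and the isinstance check against a tuple of accepted sampler classes. The is_random_sampler helper name is purely illustrative.

# Standalone restatement of the check above (illustrative, not the Trainer's code).
from packaging import version
from torch.utils.data import RandomSampler

from accelerate import __version__ as accelerate_version

sampler_kinds = [RandomSampler]
if version.parse(accelerate_version) > version.parse("0.23.0"):
    # SeedableRandomSampler is only available in newer accelerate releases.
    from accelerate.data_loader import SeedableRandomSampler

    sampler_kinds.append(SeedableRandomSampler)


def is_random_sampler(sampler) -> bool:
    # isinstance accepts a tuple of classes, so one check covers both kinds.
    return isinstance(sampler, tuple(sampler_kinds))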
@@ -1752,6 +1760,8 @@
         total_batched_samples = 0
         for epoch in range(epochs_trained, num_train_epochs):
             epoch_iterator = train_dataloader
+            if hasattr(epoch_iterator, "set_epoch"):
+                epoch_iterator.set_epoch(epoch)
 
             # Reset the past mems state at the beginning of each epoch if necessary.
             if args.past_index >= 0:
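The two added lines are the heart of the fix: the loop asks the dataloader, via hasattr, whether it exposes set_epoch (recent Accelerate-prepared dataloaders do; a plain DataLoader does not) and calls it once per epoch so shuffling differs between epochs. A hedged sketch of that duck-typed pattern in a generic loop, not the actual Trainer loop:

# Generic sketch of the pattern (illustrative only).
def train(train_dataloader, epochs_trained: int, num_train_epochs: int) -> None:
    for epoch in range(epochs_trained, num_train_epochs):
        if hasattr(train_dataloader, "set_epoch"):
            # Accelerate-prepared dataloaders expose this hook; plain DataLoaders are skipped.
            train_dataloader.set_epoch(epoch)
        for batch in train_dataloader:
            ...  # forward/backward/optimizer step would go here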