Remove skipping logic now that set_epoch exists (#30501)

* Remove skipping logic now that set_epoch exists

* Working version, clean
Author: Zach Mueller (committed 2024-04-26 11:52:09 -04:00 via GitHub)
Parent: dfa7b580e9
Commit: 77ff304d29

src/transformers/trainer.py

@@ -96,7 +96,6 @@ from .trainer_pt_utils import (
     distributed_broadcast_scalars,
     distributed_concat,
     find_batch_size,
-    get_dataloader_sampler,
     get_model_param_count,
     get_module_class_from_name,
     get_parameter_names,
@@ -2137,24 +2136,6 @@ class Trainer:
         self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
 
-        # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
-        if not args.ignore_data_skip:
-            for epoch in range(epochs_trained):
-                sampler = get_dataloader_sampler(train_dataloader)
-                sampler_kinds = [RandomSampler]
-                if version.parse(accelerate_version) > version.parse("0.23.0"):
-                    sampler_kinds.append(SeedableRandomSampler)
-                is_random_sampler = isinstance(sampler, tuple(sampler_kinds))
-                if not is_random_sampler:
-                    # We just need to begin an iteration to create the randomization of the sampler.
-                    for _ in train_dataloader:
-                        break
-                else:
-                    # Otherwise we need to call the whooooole sampler cause there is some random operation added
-                    # AT THE VERY END!
-                    sampler = sampler if sampler is not None else []
-                    _ = list(sampler)
-
         total_batched_samples = 0
         for epoch in range(epochs_trained, num_train_epochs):
             epoch_iterator = train_dataloader
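
Why the removed block is no longer needed: when resuming from a checkpoint, the Trainer previously had to replay every finished epoch so that a random sampler's RNG reached the same state it would have had in an uninterrupted run. Once the sampler supports set_epoch (as accelerate's SeedableRandomSampler does), the shuffle order for epoch N is a deterministic function of the seed and the epoch number, so a single set_epoch call replaces the replay. A minimal sketch of the idea follows; the toy SeedableRandomSampler and its seed arithmetic are illustrative assumptions, not accelerate's actual implementation:

# Toy sketch: how set_epoch makes epoch-skipping replay unnecessary.
import torch


class SeedableRandomSampler(torch.utils.data.Sampler):
    # Simplified stand-in for accelerate's SeedableRandomSampler: the
    # permutation depends only on (seed, epoch).
    def __init__(self, data_source, seed=0):
        self.data_source = data_source
        self.seed = seed
        self.epoch = 0

    def set_epoch(self, epoch):
        # Store the epoch; __iter__ derives its RNG state from it, so no
        # earlier epoch has to be iterated to reach this state.
        self.epoch = epoch

    def __iter__(self):
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
        yield from torch.randperm(len(self.data_source), generator=g).tolist()

    def __len__(self):
        return len(self.data_source)


data = list(range(8))
sampler = SeedableRandomSampler(data, seed=42)

# Old approach (the removed code above): iterate epochs_trained full epochs
# just to advance the RNG. New approach: jump straight to the target epoch.
sampler.set_epoch(3)
print(list(sampler))

Because epoch 3's permutation is derived from the same (seed, epoch) pair either way, the resumed order matches what three real passes would have produced, which is exactly why the replay loop could be deleted.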