mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Remove skipping logic now that set_epoch exists (#30501)
* Remove skipping logic now that set_epoch exists * Working version, clean
This commit is contained in:
parent
dfa7b580e9
commit
77ff304d29
@ -96,7 +96,6 @@ from .trainer_pt_utils import (
|
||||
distributed_broadcast_scalars,
|
||||
distributed_concat,
|
||||
find_batch_size,
|
||||
get_dataloader_sampler,
|
||||
get_model_param_count,
|
||||
get_module_class_from_name,
|
||||
get_parameter_names,
|
||||
@ -2137,24 +2136,6 @@ class Trainer:
|
||||
|
||||
self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
|
||||
|
||||
# Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
|
||||
if not args.ignore_data_skip:
|
||||
for epoch in range(epochs_trained):
|
||||
sampler = get_dataloader_sampler(train_dataloader)
|
||||
sampler_kinds = [RandomSampler]
|
||||
if version.parse(accelerate_version) > version.parse("0.23.0"):
|
||||
sampler_kinds.append(SeedableRandomSampler)
|
||||
is_random_sampler = isinstance(sampler, tuple(sampler_kinds))
|
||||
if not is_random_sampler:
|
||||
# We just need to begin an iteration to create the randomization of the sampler.
|
||||
for _ in train_dataloader:
|
||||
break
|
||||
else:
|
||||
# Otherwise we need to call the whooooole sampler cause there is some random operation added
|
||||
# AT THE VERY END!
|
||||
sampler = sampler if sampler is not None else []
|
||||
_ = list(sampler)
|
||||
|
||||
total_batched_samples = 0
|
||||
for epoch in range(epochs_trained, num_train_epochs):
|
||||
epoch_iterator = train_dataloader
|
||||
|
Loading…
Reference in New Issue
Block a user