Skip batches fast with accelerate (#21390)

* Skip batches fast with Accelerate

* remove debug statement

* Hack seed reload at the right time

* Reorganize RNG sync

* Fix accelerate version comp
This commit is contained in:
Sylvain Gugger 2023-02-01 10:22:05 -05:00 committed by GitHub
parent 77db257e2a
commit 8d580779a3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -138,6 +138,7 @@ from .utils import (
can_return_loss,
find_labels,
get_full_repo_name,
is_accelerate_available,
is_apex_available,
is_datasets_available,
is_in_notebook,
@ -193,6 +194,14 @@ else:
IS_SAGEMAKER_MP_POST_1_10 = False
skip_first_batches = None
if is_accelerate_available():
from accelerate import __version__ as accelerate_version
if version.parse(accelerate_version) >= version.parse("0.16"):
from accelerate import skip_first_batches
if TYPE_CHECKING:
import optuna
@ -1691,12 +1700,20 @@ class Trainer:
logger.info(f" Continuing training from epoch {epochs_trained}")
logger.info(f" Continuing training from global step {self.state.global_step}")
if not args.ignore_data_skip:
logger.info(
f" Will skip the first {epochs_trained} epochs then the first {steps_trained_in_current_epoch} "
"batches in the first epoch. If this takes a lot of time, you can add the `--ignore_data_skip` "
"flag to your launch command, but you will resume the training on data already seen by your model."
)
if self.is_local_process_zero() and not args.disable_tqdm:
if skip_first_batches is None:
logger.info(
f" Will skip the first {epochs_trained} epochs then the first"
f" {steps_trained_in_current_epoch} batches in the first epoch. If this takes a lot of time,"
" you can install the latest version of Accelerate with `pip install -U accelerate`.You can"
" also add the `--ignore_data_skip` flag to your launch command, but you will resume the"
" training on data already seen by your model."
)
else:
logger.info(
f" Will skip the first {epochs_trained} epochs then the first"
f" {steps_trained_in_current_epoch} batches in the first epoch."
)
if self.is_local_process_zero() and not args.disable_tqdm and skip_first_batches is None:
steps_trained_progress_bar = tqdm(total=steps_trained_in_current_epoch)
steps_trained_progress_bar.set_description("Skipping the first batches")
@ -1772,8 +1789,17 @@ class Trainer:
if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
self._load_rng_state(resume_from_checkpoint)
rng_to_sync = False
if skip_first_batches is not None and steps_trained_in_current_epoch > 0:
epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch)
steps_trained_in_current_epoch = 0
rng_to_sync = True
step = -1
for step, inputs in enumerate(epoch_iterator):
if rng_to_sync:
self._load_rng_state(resume_from_checkpoint)
rng_to_sync = False
# Skip past any already trained steps if resuming training
if steps_trained_in_current_epoch > 0: