mirror of https://github.com/huggingface/transformers.git
Skip batches fast with accelerate (#21390)
* Skip batches fast with Accelerate
* remove debug statement
* Hack seed reload at the right time
* Reorganize RNG sync
* Fix accelerate version comp
parent: 77db257e2a
commit: 8d580779a3
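For context, here is a minimal sketch of the Accelerate API this commit adopts (`accelerate.skip_first_batches`, available from accelerate 0.16). The toy dataset and loader are illustrative, not part of the commit:

import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import skip_first_batches  # requires accelerate >= 0.16

# Toy data: 100 scalar examples, batched 10 at a time.
dataset = TensorDataset(torch.arange(100, dtype=torch.float32).unsqueeze(1))
dataloader = DataLoader(dataset, batch_size=10)

# Fast-forward past the first 3 batches at the sampler level, without
# loading and discarding them one by one; iteration resumes at batch 3.
resumed_loader = skip_first_batches(dataloader, 3)
first_batch = next(iter(resumed_loader))
print(first_batch[0].flatten()[:3])  # tensor([30., 31., 32.])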
src/transformers/trainer.py

@@ -138,6 +138,7 @@ from .utils import (
     can_return_loss,
     find_labels,
     get_full_repo_name,
+    is_accelerate_available,
     is_apex_available,
     is_datasets_available,
     is_in_notebook,
@@ -193,6 +194,14 @@ else:
     IS_SAGEMAKER_MP_POST_1_10 = False
 
 
+skip_first_batches = None
+if is_accelerate_available():
+    from accelerate import __version__ as accelerate_version
+
+    if version.parse(accelerate_version) >= version.parse("0.16"):
+        from accelerate import skip_first_batches
+
+
 if TYPE_CHECKING:
     import optuna
 
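The guarded import above leaves `skip_first_batches` defined as None when Accelerate is missing or too old, so call sites can test the symbol instead of repeating version checks. A self-contained sketch of the same pattern, assuming only that `packaging` is installed (`fast_skip` is a hypothetical name):

from packaging import version

# Guarded optional import: the name always exists afterwards, so callers
# can branch on `fast_skip is None` rather than re-checking versions.
fast_skip = None
try:
    from accelerate import __version__ as accelerate_version

    if version.parse(accelerate_version) >= version.parse("0.16"):
        from accelerate import skip_first_batches as fast_skip
except ImportError:
    pass  # Accelerate not installed; the slow fallback will be used.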
@@ -1691,12 +1700,20 @@ class Trainer:
             logger.info(f"  Continuing training from epoch {epochs_trained}")
             logger.info(f"  Continuing training from global step {self.state.global_step}")
             if not args.ignore_data_skip:
-                logger.info(
-                    f"  Will skip the first {epochs_trained} epochs then the first {steps_trained_in_current_epoch} "
-                    "batches in the first epoch. If this takes a lot of time, you can add the `--ignore_data_skip` "
-                    "flag to your launch command, but you will resume the training on data already seen by your model."
-                )
-            if self.is_local_process_zero() and not args.disable_tqdm:
+                if skip_first_batches is None:
+                    logger.info(
+                        f"  Will skip the first {epochs_trained} epochs then the first"
+                        f" {steps_trained_in_current_epoch} batches in the first epoch. If this takes a lot of time,"
+                        " you can install the latest version of Accelerate with `pip install -U accelerate`.You can"
+                        " also add the `--ignore_data_skip` flag to your launch command, but you will resume the"
+                        " training on data already seen by your model."
+                    )
+                else:
+                    logger.info(
+                        f"  Will skip the first {epochs_trained} epochs then the first"
+                        f" {steps_trained_in_current_epoch} batches in the first epoch."
+                    )
+            if self.is_local_process_zero() and not args.disable_tqdm and skip_first_batches is None:
                 steps_trained_progress_bar = tqdm(total=steps_trained_in_current_epoch)
                 steps_trained_progress_bar.set_description("Skipping the first batches")
 
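Net effect of the hunk above: with a recent Accelerate, the resume log drops the warning about slow skipping, and the persistent "Skipping the first batches" tqdm bar is only created on the fallback path, since the fast path no longer iterates over the skipped batches one by one.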
@@ -1772,8 +1789,17 @@ class Trainer:
             if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
                 self._load_rng_state(resume_from_checkpoint)
 
+            rng_to_sync = False
+            if skip_first_batches is not None and steps_trained_in_current_epoch > 0:
+                epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch)
+                steps_trained_in_current_epoch = 0
+                rng_to_sync = True
+
             step = -1
             for step, inputs in enumerate(epoch_iterator):
+                if rng_to_sync:
+                    self._load_rng_state(resume_from_checkpoint)
+                    rng_to_sync = False
 
                 # Skip past any already trained steps if resuming training
                 if steps_trained_in_current_epoch > 0:
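The `rng_to_sync` flag above defers `_load_rng_state` until the first real batch after the fast-forward: the skip happens first (its iterator setup may consume RNG), and the checkpointed RNG state is restored immediately before the first resumed step, so random draws such as dropout line up with an uninterrupted run. A toy sketch of that ordering, with illustrative helpers that are not Trainer internals:

import torch

def save_rng_state():
    # Illustrative: the real Trainer also saves numpy/python/CUDA states.
    return {"torch": torch.get_rng_state()}

def load_rng_state(state):
    torch.set_rng_state(state["torch"])

checkpoint_rng = save_rng_state()  # as if saved at checkpoint time
batches, steps_done = list(range(10)), 4

# Fast-forward first (setup may touch the RNG)...
epoch_iterator = iter(batches[steps_done:])
rng_to_sync = True

for step, batch in enumerate(epoch_iterator):
    if rng_to_sync:
        # ...then restore RNG right before the first real training step.
        load_rng_state(checkpoint_rng)
        rng_to_sync = False
    _ = torch.rand(1)  # stands in for dropout and other random draws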