[Trainer] Add a progress bar for batches skipped (#11324)

This commit is contained in:
Sylvain Gugger 2021-04-19 19:04:52 -04:00 committed by GitHub
parent 95ffbe1686
commit 95037a169f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -29,6 +29,8 @@ from logging import StreamHandler
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
from tqdm.auto import tqdm
# Integrations must be imported before ML frameworks:
from .integrations import ( # isort: split
@ -1097,6 +1099,7 @@ class Trainer:
start_time = time.time()
epochs_trained = 0
steps_trained_in_current_epoch = 0
steps_trained_progress_bar = None
# Check if continuing training from a checkpoint
if resume_from_checkpoint is not None and os.path.isfile(
@ -1116,8 +1119,12 @@ class Trainer:
if not args.ignore_data_skip:
logger.info(
f" Will skip the first {epochs_trained} epochs then the first {steps_trained_in_current_epoch} "
"batches in the first epoch."
"batches in the first epoch. If this takes a lot of time, you can add the `--ignore_data_skip` "
"flag to your launch command, but you will resume the training on data already seen by your model."
)
if self.is_local_process_zero() and not args.disable_tqdm:
steps_trained_progress_bar = tqdm(total=steps_trained_in_current_epoch)
steps_trained_progress_bar.set_description("Skipping the first batches")
# Update the references
self.callback_handler.model = self.model
@ -1176,7 +1183,12 @@ class Trainer:
# Skip past any already trained steps if resuming training
if steps_trained_in_current_epoch > 0:
steps_trained_in_current_epoch -= 1
if steps_trained_progress_bar is not None:
steps_trained_progress_bar.update(1)
continue
elif steps_trained_progress_bar is not None:
steps_trained_progress_bar.close()
steps_trained_progress_bar = None
if step % args.gradient_accumulation_steps == 0:
self.control = self.callback_handler.on_step_begin(args, self.state, self.control)