mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 10:12:23 +06:00
[Trainer] Add a progress bar for batches skipped (#11324)
This commit is contained in:
parent
95ffbe1686
commit
95037a169f
@@ -29,6 +29,8 @@ from logging import StreamHandler
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
|
||||
# Integrations must be imported before ML frameworks:
|
||||
from .integrations import ( # isort: split
|
||||
@@ -1097,6 +1099,7 @@ class Trainer:
|
||||
start_time = time.time()
|
||||
epochs_trained = 0
|
||||
steps_trained_in_current_epoch = 0
|
||||
steps_trained_progress_bar = None
|
||||
|
||||
# Check if continuing training from a checkpoint
|
||||
if resume_from_checkpoint is not None and os.path.isfile(
|
||||
@@ -1116,8 +1119,12 @@ class Trainer:
|
||||
if not args.ignore_data_skip:
|
||||
logger.info(
|
||||
f" Will skip the first {epochs_trained} epochs then the first {steps_trained_in_current_epoch} "
|
||||
"batches in the first epoch."
|
||||
"batches in the first epoch. If this takes a lot of time, you can add the `--ignore_data_skip` "
|
||||
"flag to your launch command, but you will resume the training on data already seen by your model."
|
||||
)
|
||||
if self.is_local_process_zero() and not args.disable_tqdm:
|
||||
steps_trained_progress_bar = tqdm(total=steps_trained_in_current_epoch)
|
||||
steps_trained_progress_bar.set_description("Skipping the first batches")
|
||||
|
||||
# Update the references
|
||||
self.callback_handler.model = self.model
|
||||
@@ -1176,7 +1183,12 @@ class Trainer:
|
||||
# Skip past any already trained steps if resuming training
|
||||
if steps_trained_in_current_epoch > 0:
|
||||
steps_trained_in_current_epoch -= 1
|
||||
if steps_trained_progress_bar is not None:
|
||||
steps_trained_progress_bar.update(1)
|
||||
continue
|
||||
elif steps_trained_progress_bar is not None:
|
||||
steps_trained_progress_bar.close()
|
||||
steps_trained_progress_bar = None
|
||||
|
||||
if step % args.gradient_accumulation_steps == 0:
|
||||
self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
|
||||
|
Loading…
Reference in New Issue
Block a user