Clarify batch size displayed when using DataParallel (#24430)

This commit is contained in:
Sylvain Gugger 2023-06-22 14:46:20 -04:00 committed by GitHub
parent b6295b26c5
commit 2834c17ad2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23


@@ -1671,7 +1671,9 @@ class Trainer:
         logger.info("***** Running training *****")
         logger.info(f"  Num examples = {num_examples:,}")
         logger.info(f"  Num Epochs = {num_train_epochs:,}")
-        logger.info(f"  Instantaneous batch size per device = {self._train_batch_size:,}")
+        logger.info(f"  Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}")
+        if self.args.per_device_train_batch_size != self._train_batch_size:
+            logger.info(f"  Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}")
         logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}")
         logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
         logger.info(f"  Total optimization steps = {max_steps:,}")