diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 03ace47468a..1a6f680d65a 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -1753,6 +1753,11 @@ class TrainingArguments:
                     )
                 else:
                     self.accelerator_config = AcceleratorConfig.from_json_file(self.accelerator_config)
+            if self.accelerator_config.split_batches:
+                logger.info(
+                    "Using `split_batches=True` in `accelerator_config` will override the `per_device_train_batch_size`. "
+                    "Batches will be split across all processes equally when using `split_batches=True`."
+                )
 
         # Initialize device before we proceed
         if self.framework == "pt" and is_torch_available():
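
For context while reviewing, here is a minimal sketch (not part of the diff) of a configuration that would hit the new log line; it assumes `AcceleratorConfig` is imported from `transformers.trainer_pt_utils` and passed directly as `accelerator_config`, which the surrounding code path already accepts.

```python
# Minimal sketch, not part of this PR: build TrainingArguments with
# split_batches=True so the logger.info added above fires in __post_init__.
from transformers import TrainingArguments
from transformers.trainer_pt_utils import AcceleratorConfig

args = TrainingArguments(
    output_dir="tmp_trainer",
    per_device_train_batch_size=8,
    # With split_batches=True, Accelerate splits each fetched batch across
    # processes instead of multiplying per_device_train_batch_size by the
    # number of processes, which is what the added message points out.
    accelerator_config=AcceleratorConfig(split_batches=True),
)
```

In other words, with `split_batches=True` the batch drawn from the dataloader is divided equally across processes, so `per_device_train_batch_size` effectively acts as the global batch size rather than a per-process one.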