fix total batch size calculation in trainer (#38286)

* fix total batch size calculation

* update

Signed-off-by: inkcherry <mingzhi.liu@intel.com>

* Update src/transformers/trainer.py

---------

Signed-off-by: inkcherry <mingzhi.liu@intel.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
inkcherry 2025-06-06 22:54:00 +08:00 committed by GitHub
parent 02f946a038
commit 871901cb3d
2 changed files with 31 additions and 4 deletions

src/transformers/integrations/deepspeed.py

@@ -458,11 +458,16 @@ def deepspeed_init(trainer, num_training_steps, inference=False):
         model_parameters = None
     else:
         trainer.optimizer = None  # important for when deepspeed_init is used as re-init
-        tp_size = hf_deepspeed_config.config.get("tensor_parallel", {}).get("autotp_size", 0)
-        if tp_size > 1:
+        deepspeed_tp_size = hf_deepspeed_config.config.get("tensor_parallel", {}).get("autotp_size", 1)
+        if deepspeed_tp_size > 1:
             import deepspeed
-            model = deepspeed.tp_model_init(model=model, tp_size=tp_size, dtype=hf_deepspeed_config.dtype())
+            model = deepspeed.tp_model_init(
+                model=model,
+                tp_size=deepspeed_tp_size,
+                dtype=hf_deepspeed_config.dtype(),
+                config=hf_deepspeed_config.config,
+            )
         model_parameters = list(filter(lambda p: p.requires_grad, model.parameters()))

     optimizer, lr_scheduler = deepspeed_optim_sched(
         trainer, hf_deepspeed_config, args, num_training_steps, model_parameters
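
For reference, the `autotp_size` read above comes from the `tensor_parallel` section of the user's DeepSpeed config dict. A minimal sketch of such a config follows; only `tensor_parallel.autotp_size` is read by the code above, and all values (and the other keys) are illustrative, not part of this commit:

# Minimal sketch of a DeepSpeed config dict that enables automatic tensor
# parallelism (AutoTP). Values here are illustrative assumptions.
ds_config = {
    "train_micro_batch_size_per_gpu": 2,
    "gradient_accumulation_steps": 4,
    "tensor_parallel": {"autotp_size": 2},  # shard layers across 2 ranks
    "zero_optimization": {"stage": 1},
}

# With the new default of 1, a config without "tensor_parallel" means no TP:
tp_size = ds_config.get("tensor_parallel", {}).get("autotp_size", 1)
assert tp_size == 2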

src/transformers/trainer.py

@@ -2238,6 +2238,27 @@ class Trainer:
             ignore_keys_for_eval=ignore_keys_for_eval,
         )

+    def get_tp_size(self) -> int:
+        """Get the tensor parallel size from either the model or the DeepSpeed config."""
+        # 1. Check model._tp_size first
+        if (model_tp := getattr(self.model, "_tp_size", None)) is not None:
+            return model_tp
+
+        # 2. Fall back to the DeepSpeed config if enabled
+        if self.is_deepspeed_enabled and (deepspeed_config := getattr(self.args, "hf_deepspeed_config", None)):
+            return deepspeed_config.config.get("tensor_parallel", {}).get("autotp_size", 1)
+
+        # 3. Default fallback
+        return 1
+
+    def get_total_train_batch_size(self, args) -> int:
+        """Calculates the total train batch size (micro_batch * grad_accum * dp_world_size).
+
+        Note: only considers DP and TP (dp_world_size = world_size // tp_size).
+        """
+        dp_world_size = args.world_size // self.get_tp_size()
+        return self._train_batch_size * args.gradient_accumulation_steps * dp_world_size
+
     def _inner_training_loop(
         self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
     ):
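
A concrete check of the corrected arithmetic, using hypothetical numbers that are not from this commit: tensor-parallel ranks hold shards of the same model replica, so only world_size // tp_size replicas consume distinct data, and the old formula double-counted the TP ranks.

# Hypothetical numbers illustrating the fixed formula vs. the old one.
world_size = 8     # total processes (args.world_size)
tp_size = 2        # tensor_parallel.autotp_size
micro_batch = 4    # per-device train batch size
grad_accum = 8     # gradient accumulation steps

dp_world_size = world_size // tp_size                     # 4 data-parallel replicas
fixed_total = micro_batch * grad_accum * dp_world_size    # 128 samples per optimizer step
old_total = micro_batch * grad_accum * world_size         # 256: double-counts TP ranks
print(fixed_total, old_total)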
@@ -2268,7 +2289,8 @@ class Trainer:
         # number of training epochs: num_train_epochs
         # number of training steps per epoch: num_update_steps_per_epoch
         # total number of training steps to execute: max_steps
-        total_train_batch_size = self._train_batch_size * args.gradient_accumulation_steps * args.world_size
+        total_train_batch_size = self.get_total_train_batch_size(args)
+
         (
             num_train_epochs,
             num_update_steps_per_epoch,
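
For intuition, here is a rough sketch of how a total train batch size converts a dataset size and epoch count into optimizer steps; an inflated total (the old formula under TP) would under-count steps. The variable names are illustrative assumptions, not the Trainer's internal bookkeeping:

# Illustrative only: epoch-to-step conversion driven by the total batch size.
num_examples = 51_200
num_train_epochs = 3
total_train_batch_size = 128  # e.g. from get_total_train_batch_size(args)

steps_per_epoch = num_examples // total_train_batch_size  # 400
max_steps = steps_per_epoch * num_train_epochs            # 1200
print(steps_per_epoch, max_steps)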