Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-03 12:50:06 +06:00)
fix total batch size calculation in trainer (#38286)
* fix total batch size calculation

* update

Signed-off-by: inkcherry <mingzhi.liu@intel.com>

* Update src/transformers/trainer.py

---------

Signed-off-by: inkcherry <mingzhi.liu@intel.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
parent 02f946a038
commit 871901cb3d
@@ -458,11 +458,16 @@ def deepspeed_init(trainer, num_training_steps, inference=False):
         model_parameters = None
     else:
         trainer.optimizer = None  # important for when deepspeed_init is used as re-init
-        tp_size = hf_deepspeed_config.config.get("tensor_parallel", {}).get("autotp_size", 0)
-        if tp_size > 1:
+        deepspeed_tp_size = hf_deepspeed_config.config.get("tensor_parallel", {}).get("autotp_size", 1)
+        if deepspeed_tp_size > 1:
             import deepspeed
 
-            model = deepspeed.tp_model_init(model=model, tp_size=tp_size, dtype=hf_deepspeed_config.dtype())
+            model = deepspeed.tp_model_init(
+                model=model,
+                tp_size=deepspeed_tp_size,
+                dtype=hf_deepspeed_config.dtype(),
+                config=hf_deepspeed_config.config,
+            )
         model_parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
     optimizer, lr_scheduler = deepspeed_optim_sched(
         trainer, hf_deepspeed_config, args, num_training_steps, model_parameters
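For reference, the autotp_size value read above is nested under the tensor_parallel section of the DeepSpeed config. Below is a minimal sketch of such a config fragment together with the same lookup; the surrounding keys are illustrative assumptions, not taken from this commit:

# Hypothetical DeepSpeed config fragment; only the "tensor_parallel" section matters here.
ds_config = {
    "train_micro_batch_size_per_gpu": 4,
    "tensor_parallel": {"autotp_size": 2},
}

# Same lookup as the patched code above; defaults to 1 (no tensor parallelism) when the key is absent.
tp_size = ds_config.get("tensor_parallel", {}).get("autotp_size", 1)
print(tp_size)  # -> 2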
@@ -2238,6 +2238,27 @@ class Trainer:
             ignore_keys_for_eval=ignore_keys_for_eval,
         )
 
+    def get_tp_size(self) -> int:
+        """Get the tensor parallel size from either the model or DeepSpeed config."""
+
+        # 1. Check model.tp_size first
+        if (model_tp := getattr(self.model, "_tp_size", None)) is not None:
+            return model_tp
+
+        # 2. Fall back to DeepSpeed config if enabled
+        if self.is_deepspeed_enabled and (deepspeed_config := getattr(self.args, "hf_deepspeed_config", None)):
+            return deepspeed_config.config.get("tensor_parallel", {}).get("autotp_size", 1)
+
+        # 3. Default fallback
+        return 1
+
+    def get_total_train_batch_size(self, args) -> int:
+        """Calculates total batch size (micro_batch * grad_accum * dp_world_size).
+
+        Note: Only considers DP and TP (dp_world_size = world_size // tp_size)."""
+        dp_world_size = args.world_size // self.get_tp_size()
+        return self._train_batch_size * args.gradient_accumulation_steps * dp_world_size
+
     def _inner_training_loop(
         self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
     ):
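To make the resolution order of the new get_tp_size helper easier to follow, here is a small standalone sketch of the same three-step fallback; the stub objects and the resolve_tp_size name are hypothetical and only mirror the logic added above:

# Minimal sketch of the tp-size resolution order: model._tp_size first,
# then the DeepSpeed "tensor_parallel.autotp_size" entry, and finally 1.
from types import SimpleNamespace

def resolve_tp_size(model, args, is_deepspeed_enabled):
    # 1. Prefer the size recorded on the model itself.
    if (model_tp := getattr(model, "_tp_size", None)) is not None:
        return model_tp
    # 2. Otherwise consult the DeepSpeed config, if DeepSpeed is active.
    if is_deepspeed_enabled and (ds_cfg := getattr(args, "hf_deepspeed_config", None)):
        return ds_cfg.config.get("tensor_parallel", {}).get("autotp_size", 1)
    # 3. No TP information anywhere: assume no tensor parallelism.
    return 1

model = SimpleNamespace(_tp_size=None)
ds_cfg = SimpleNamespace(config={"tensor_parallel": {"autotp_size": 2}})
args = SimpleNamespace(hf_deepspeed_config=ds_cfg)
print(resolve_tp_size(model, args, is_deepspeed_enabled=True))  # -> 2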
@@ -2268,7 +2289,8 @@ class Trainer:
         # number of training epochs: num_train_epochs
         # number of training steps per epoch: num_update_steps_per_epoch
         # total number of training steps to execute: max_steps
-        total_train_batch_size = self._train_batch_size * args.gradient_accumulation_steps * args.world_size
+        total_train_batch_size = self.get_total_train_batch_size(args)
+
         (
             num_train_epochs,
             num_update_steps_per_epoch,
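With tensor parallelism enabled, args.world_size counts TP ranks as well as data-parallel replicas, so the old expression overstated the effective batch size by a factor of tp_size. An illustrative comparison with assumed numbers (8 processes, TP group of 2, so 4 data-parallel replicas):

# Illustrative numbers only, not taken from the commit.
micro_batch, grad_accum, world_size, tp_size = 4, 8, 8, 2

old_total = micro_batch * grad_accum * world_size               # 256: treats TP ranks as extra data parallelism
new_total = micro_batch * grad_accum * (world_size // tp_size)  # 128: only DP replicas contribute samples
print(old_total, new_total)

Only the data-parallel replicas feed distinct samples into each optimizer step, which is what the new helper accounts for.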