diff --git a/src/transformers/integrations/deepspeed.py b/src/transformers/integrations/deepspeed.py
index 1700301db51..dabc9a6ef95 100644
--- a/src/transformers/integrations/deepspeed.py
+++ b/src/transformers/integrations/deepspeed.py
@@ -458,11 +458,16 @@ def deepspeed_init(trainer, num_training_steps, inference=False):
         model_parameters = None
     else:
         trainer.optimizer = None  # important for when deepspeed_init is used as re-init
-        tp_size = hf_deepspeed_config.config.get("tensor_parallel", {}).get("autotp_size", 0)
-        if tp_size > 1:
+        deepspeed_tp_size = hf_deepspeed_config.config.get("tensor_parallel", {}).get("autotp_size", 1)
+        if deepspeed_tp_size > 1:
             import deepspeed
 
-            model = deepspeed.tp_model_init(model=model, tp_size=tp_size, dtype=hf_deepspeed_config.dtype())
+            model = deepspeed.tp_model_init(
+                model=model,
+                tp_size=deepspeed_tp_size,
+                dtype=hf_deepspeed_config.dtype(),
+                config=hf_deepspeed_config.config,
+            )
         model_parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
         optimizer, lr_scheduler = deepspeed_optim_sched(
             trainer, hf_deepspeed_config, args, num_training_steps, model_parameters
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index d7c567f5391..7044ea040bf 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -2238,6 +2238,27 @@ class Trainer:
             ignore_keys_for_eval=ignore_keys_for_eval,
         )
 
+    def get_tp_size(self) -> int:
+        """Get the tensor parallel size from either the model or the DeepSpeed config."""
+
+        # 1. Check model._tp_size first
+        if (model_tp := getattr(self.model, "_tp_size", None)) is not None:
+            return model_tp
+
+        # 2. Fall back to the DeepSpeed config if enabled
+        if self.is_deepspeed_enabled and (deepspeed_config := getattr(self.args, "hf_deepspeed_config", None)):
+            return deepspeed_config.config.get("tensor_parallel", {}).get("autotp_size", 1)
+
+        # 3. Default fallback
+        return 1
+
+    def get_total_train_batch_size(self, args) -> int:
+        """Calculates the total train batch size (micro_batch * grad_accum * dp_world_size).
+
+        Note: Only considers DP and TP (dp_world_size = world_size // tp_size)."""
+        dp_world_size = args.world_size // self.get_tp_size()
+        return self._train_batch_size * args.gradient_accumulation_steps * dp_world_size
+
     def _inner_training_loop(
         self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
     ):
@@ -2268,7 +2289,8 @@ class Trainer:
         # number of training epochs: num_train_epochs
         # number of training steps per epoch: num_update_steps_per_epoch
         # total number of training steps to execute: max_steps
-        total_train_batch_size = self._train_batch_size * args.gradient_accumulation_steps * args.world_size
+        total_train_batch_size = self.get_total_train_batch_size(args)
+
         (
             num_train_epochs,
             num_update_steps_per_epoch,
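
To illustrate the batch-size arithmetic this patch introduces: before the change, total_train_batch_size multiplied by the full args.world_size, which overstates the effective batch by a factor of tp_size when DeepSpeed AutoTP is enabled, since tensor-parallel ranks shard the model rather than the data and all ranks in a TP group process the same micro-batch. A minimal standalone sketch of the new computation follows; the function and parameter names (total_train_batch_size, micro_batch_size, grad_accum_steps) are illustrative, not Trainer attributes.

# Minimal sketch of the DP/TP arithmetic in get_total_train_batch_size.
# Names here are illustrative, not part of the Trainer API.
def total_train_batch_size(micro_batch_size, grad_accum_steps, world_size, tp_size=1):
    # Every rank in a TP group sees the same micro-batch, so only the
    # data-parallel replicas contribute distinct samples per step.
    dp_world_size = world_size // tp_size
    return micro_batch_size * grad_accum_steps * dp_world_size

# 8 GPUs with autotp_size=2 -> 4 data-parallel replicas, not 8.
assert total_train_batch_size(micro_batch_size=4, grad_accum_steps=2, world_size=8, tp_size=2) == 32
# With tp_size=1 this reduces to the pre-patch formula (micro_batch * grad_accum * world_size).
assert total_train_batch_size(4, 2, 8) == 64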