Fix hyperparameter search when optuna+deepspeed (#34642)

* Fix hyperparameter search when optuna+deepspeed

* Adding free_memory to the search setup

---------

Co-authored-by: Corentin-Royer <corentin.royer@ibm.com>
Corentin Royer committed on 2024-11-20 18:02:58 +01:00 (committed via GitHub)
parent 67890de3b8, commit bf42c3bd4b
2 changed files with 14 additions and 14 deletions
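
For context, a minimal sketch of the workflow this commit targets: an Optuna search driven through Trainer.hyperparameter_search while DeepSpeed is enabled, launched under a distributed launcher. The model name, the train_ds/eval_ds variables, and the ds_config.json path are placeholders, not part of the change.

    from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments

    def model_init(trial):
        # Trainer re-instantiates the model at the start of every trial.
        return AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")

    def hp_space(trial):
        # Search space sampled once per trial by the Optuna backend.
        return {
            "learning_rate": trial.suggest_float("learning_rate", 1e-5, 5e-5, log=True),
            "per_device_train_batch_size": trial.suggest_categorical("per_device_train_batch_size", [8, 16]),
        }

    args = TrainingArguments(
        output_dir="hp-search",
        deepspeed="ds_config.json",  # placeholder DeepSpeed config; this is what makes is_deepspeed_enabled True
        eval_strategy="epoch",
    )
    trainer = Trainer(model=None, args=args, model_init=model_init,
                      train_dataset=train_ds, eval_dataset=eval_ds)  # train_ds/eval_ds assumed to exist
    best_run = trainer.hyperparameter_search(direction="minimize", backend="optuna",
                                             hp_space=hp_space, n_trials=4)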


@@ -208,7 +208,7 @@ def hp_params(trial):
     if is_optuna_available():
         import optuna
 
-        if isinstance(trial, optuna.Trial):
+        if isinstance(trial, optuna.trial.BaseTrial):
             return trial.params
     if is_ray_tune_available():
         if isinstance(trial, dict):
@@ -230,7 +230,7 @@ def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
     if trainer.args.process_index == 0:
 
-        def _objective(trial, checkpoint_dir=None):
+        def _objective(trial: optuna.Trial, checkpoint_dir=None):
             checkpoint = None
             if checkpoint_dir:
                 for subdir in os.listdir(checkpoint_dir):
@@ -240,10 +240,11 @@ def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
             if trainer.args.world_size > 1:
                 if trainer.args.parallel_mode != ParallelMode.DISTRIBUTED:
                     raise RuntimeError("only support DDP optuna HPO for ParallelMode.DISTRIBUTED currently.")
-                trainer._hp_search_setup(trial)
-                args_main_rank_list = [pickle.dumps(trainer.args)]
-                torch.distributed.broadcast_object_list(args_main_rank_list, src=0)
-                trainer.train(resume_from_checkpoint=checkpoint)
+                trainer.hp_space(trial)
+                fixed_trial = optuna.trial.FixedTrial(trial.params, trial.number)
+                trial_main_rank_list = [fixed_trial]
+                torch.distributed.broadcast_object_list(trial_main_rank_list, src=0)
+                trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
             else:
                 trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
             # If there hasn't been any evaluation during the training loop.
@@ -268,15 +269,11 @@ def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
     else:
         for i in range(n_trials):
             trainer.objective = None
-            args_main_rank_list = [None]
+            trial_main_rank_list = [None]
             if trainer.args.parallel_mode != ParallelMode.DISTRIBUTED:
                 raise RuntimeError("only support DDP optuna HPO for ParallelMode.DISTRIBUTED currently.")
-            torch.distributed.broadcast_object_list(args_main_rank_list, src=0)
-            args = pickle.loads(bytes(args_main_rank_list[0]))
-            for key, value in asdict(args).items():
-                if key != "local_rank":
-                    setattr(trainer.args, key, value)
-            trainer.train(resume_from_checkpoint=None)
+            torch.distributed.broadcast_object_list(trial_main_rank_list, src=0)
+            trainer.train(resume_from_checkpoint=None, trial=trial_main_rank_list[0])
             # If there hasn't been any evaluation during the training loop.
             if getattr(trainer, "objective", None) is None:
                 metrics = trainer.evaluate()
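
The run_hp_search_optuna hunks above replace the old approach of pickling and re-applying TrainingArguments with broadcasting the trial itself: rank 0 samples the hyperparameters (trainer.hp_space(trial) populates trial.params), wraps them in an optuna.trial.FixedTrial, and broadcasts it so every rank passes an equivalent trial object to trainer.train. A standalone sketch of that pattern, not part of the diff, assuming a torch.distributed process group is already initialized:

    import optuna
    import torch.distributed as dist

    def broadcast_trial(trial=None):
        # On rank 0 `trial` is the live optuna.Trial; on other ranks it is None.
        if dist.get_rank() == 0:
            # FixedTrial replays the already-sampled params: its suggest_* methods
            # return the stored values, so hp_space(trial) is identical on every rank.
            payload = [optuna.trial.FixedTrial(trial.params, trial.number)]
        else:
            payload = [None]
        # broadcast_object_list pickles the object on src and fills the list in place.
        dist.broadcast_object_list(payload, src=0)
        return payload[0]

Because a FixedTrial is a plain picklable object, this avoids the old round-trip through pickle.dumps(trainer.args) and the manual setattr loop on the non-zero ranks.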


@@ -1725,6 +1725,9 @@ class Trainer:
         if self.is_deepspeed_enabled:
             if self.args.deepspeed is None:
                 raise ValueError("For sweeps with deepspeed, `args.deepspeed` must be set")
+
+            self.accelerator.free_memory()
+
             # Rebuild the deepspeed config to reflect the updated training parameters
             from accelerate.utils import DeepSpeedPlugin
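
The new self.accelerator.free_memory() call is meant to release the previous trial's prepared objects before the DeepSpeed plugin is rebuilt for the next trial. A rough, illustrative-only sketch of that Accelerate call; the prepare step is only indicated:

    from accelerate import Accelerator

    accelerator = Accelerator()
    # ... a trial prepares model/optimizer/dataloaders via accelerator.prepare(...) ...
    # Dropping Accelerate's internal references and clearing the CUDA cache lets the
    # previous trial's engine be garbage-collected before the next trial starts.
    accelerator.free_memory()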
@@ -1748,7 +1751,7 @@ class Trainer:
         if self.hp_search_backend == HPSearchBackend.OPTUNA:
             import optuna
 
-            if not trial.study._is_multi_objective():
+            if hasattr(trial, "study") and not trial.study._is_multi_objective():
                 trial.report(self.objective, step)
                 if trial.should_prune():
                     self.callback_handler.on_train_end(self.args, self.state, self.control)
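
The hasattr(trial, "study") guard is needed because only a live optuna.Trial is attached to a Study; the FixedTrial that non-zero ranks receive from the broadcast has no study attribute, so reporting and pruning are skipped there. A small illustrative check, assuming only that optuna is installed:

    import optuna

    fixed = optuna.trial.FixedTrial({"learning_rate": 3e-5}, number=7)
    print(hasattr(fixed, "study"))  # False -> the report/prune branch is skipped on non-zero ranks

    def objective(trial):
        print(hasattr(trial, "study"))  # True for a live Trial created by a Study
        return trial.suggest_float("learning_rate", 1e-5, 5e-5, log=True)

    optuna.create_study(direction="minimize").optimize(objective, n_trials=1)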