Added additional kwarg for successful running of optuna hyperparameter search (#31924)

Update integration_utils.py

Added additional kwarg
This commit is contained in:
Deep Gandhi 2024-07-23 19:11:52 +05:30 committed by GitHub
parent 63700628ad
commit 7d92009af6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -252,10 +252,11 @@ def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> Be
timeout = kwargs.pop("timeout", None)
n_jobs = kwargs.pop("n_jobs", 1)
gc_after_trial = kwargs.pop("gc_after_trial", False)
directions = direction if isinstance(direction, list) else None
direction = None if directions is not None else direction
study = optuna.create_study(direction=direction, directions=directions, **kwargs)
study.optimize(_objective, n_trials=n_trials, timeout=timeout, n_jobs=n_jobs)
study.optimize(_objective, n_trials=n_trials, timeout=timeout, n_jobs=n_jobs, gc_after_trial=gc_after_trial)
if not study._is_multi_objective():
best_trial = study.best_trial
return BestRun(str(best_trial.number), best_trial.value, best_trial.params)