mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
enable optuna multi-objectives feature (#25969)
* enable optuna multi-objectives feature Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * Apply suggestions from code review Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * update hpo doc * update docstring Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * extend direction to List[str] type Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * Update src/transformers/integrations/integration_utils.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
parent
92f2fbad50
commit
8f609ab9e0
@ -54,6 +54,18 @@ For optuna, see optuna [object_parameter](https://optuna.readthedocs.io/en/stabl
|
||||
... }
|
||||
```
|
||||
|
||||
Optuna provides multi-objective HPO. You can pass `direction` in `hyperparameter_search` and define your own compute_objective to return multiple objective values. The Pareto Front (`List[BestRun]`) will be returned in hyperparameter_search, you should refer to the test case `TrainerHyperParameterMultiObjectOptunaIntegrationTest` in [test_trainer](https://github.com/huggingface/transformers/blob/main/tests/trainer/test_trainer.py). It's like following
|
||||
|
||||
```py
|
||||
>>> best_trials = trainer.hyperparameter_search(
|
||||
... direction=["minimize", "maximize"],
|
||||
... backend="optuna",
|
||||
... hp_space=optuna_hp_space,
|
||||
... n_trials=20,
|
||||
... compute_objective=compute_objective,
|
||||
... )
|
||||
```
|
||||
|
||||
For raytune, see raytune [object_parameter](https://docs.ray.io/en/latest/tune/api/search_space.html), it's like following:
|
||||
|
||||
```py
|
||||
|
@ -205,10 +205,16 @@ def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> Be
|
||||
|
||||
timeout = kwargs.pop("timeout", None)
|
||||
n_jobs = kwargs.pop("n_jobs", 1)
|
||||
study = optuna.create_study(direction=direction, **kwargs)
|
||||
directions = direction if isinstance(direction, list) else None
|
||||
direction = None if directions is not None else direction
|
||||
study = optuna.create_study(direction=direction, directions=directions, **kwargs)
|
||||
study.optimize(_objective, n_trials=n_trials, timeout=timeout, n_jobs=n_jobs)
|
||||
best_trial = study.best_trial
|
||||
return BestRun(str(best_trial.number), best_trial.value, best_trial.params)
|
||||
if not study._is_multi_objective():
|
||||
best_trial = study.best_trial
|
||||
return BestRun(str(best_trial.number), best_trial.value, best_trial.params)
|
||||
else:
|
||||
best_trials = study.best_trials
|
||||
return [BestRun(str(best.number), best.values, best.params) for best in best_trials]
|
||||
else:
|
||||
for i in range(n_trials):
|
||||
trainer.objective = None
|
||||
|
@ -1233,10 +1233,11 @@ class Trainer:
|
||||
if self.hp_search_backend == HPSearchBackend.OPTUNA:
|
||||
import optuna
|
||||
|
||||
trial.report(self.objective, step)
|
||||
if trial.should_prune():
|
||||
self.callback_handler.on_train_end(self.args, self.state, self.control)
|
||||
raise optuna.TrialPruned()
|
||||
if not trial.study._is_multi_objective():
|
||||
trial.report(self.objective, step)
|
||||
if trial.should_prune():
|
||||
self.callback_handler.on_train_end(self.args, self.state, self.control)
|
||||
raise optuna.TrialPruned()
|
||||
elif self.hp_search_backend == HPSearchBackend.RAY:
|
||||
from ray import tune
|
||||
|
||||
@ -2563,11 +2564,11 @@ class Trainer:
|
||||
hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None,
|
||||
compute_objective: Optional[Callable[[Dict[str, float]], float]] = None,
|
||||
n_trials: int = 20,
|
||||
direction: str = "minimize",
|
||||
direction: Union[str, List[str]] = "minimize",
|
||||
backend: Optional[Union["str", HPSearchBackend]] = None,
|
||||
hp_name: Optional[Callable[["optuna.Trial"], str]] = None,
|
||||
**kwargs,
|
||||
) -> BestRun:
|
||||
) -> Union[BestRun, List[BestRun]]:
|
||||
"""
|
||||
Launch an hyperparameter search using `optuna` or `Ray Tune` or `SigOpt`. The optimized quantity is determined
|
||||
by `compute_objective`, which defaults to a function returning the evaluation loss when no metric is provided,
|
||||
@ -2592,9 +2593,12 @@ class Trainer:
|
||||
method. Will default to [`~trainer_utils.default_compute_objective`].
|
||||
n_trials (`int`, *optional*, defaults to 100):
|
||||
The number of trial runs to test.
|
||||
direction (`str`, *optional*, defaults to `"minimize"`):
|
||||
Whether to optimize greater or lower objects. Can be `"minimize"` or `"maximize"`, you should pick
|
||||
`"minimize"` when optimizing the validation loss, `"maximize"` when optimizing one or several metrics.
|
||||
direction (`str` or `List[str]`, *optional*, defaults to `"minimize"`):
|
||||
If it's single objective optimization, direction is `str`, can be `"minimize"` or `"maximize"`, you
|
||||
should pick `"minimize"` when optimizing the validation loss, `"maximize"` when optimizing one or
|
||||
several metrics. If it's multi objectives optimization, direction is `List[str]`, can be List of
|
||||
`"minimize"` and `"maximize"`, you should pick `"minimize"` when optimizing the validation loss,
|
||||
`"maximize"` when optimizing one or several metrics.
|
||||
backend (`str` or [`~training_utils.HPSearchBackend`], *optional*):
|
||||
The backend to use for hyperparameter search. Will default to optuna or Ray Tune or SigOpt, depending
|
||||
on which one is installed. If all are installed, will default to optuna.
|
||||
@ -2610,8 +2614,9 @@ class Trainer:
|
||||
- the documentation of [sigopt](https://app.sigopt.com/docs/endpoints/experiments/create)
|
||||
|
||||
Returns:
|
||||
[`trainer_utils.BestRun`]: All the information about the best run. Experiment summary can be found in
|
||||
`run_summary` attribute for Ray backend.
|
||||
[`trainer_utils.BestRun` or `List[trainer_utils.BestRun]`]: All the information about the best run or best
|
||||
runs for multi-objective optimization. Experiment summary can be found in `run_summary` attribute for Ray
|
||||
backend.
|
||||
"""
|
||||
if backend is None:
|
||||
backend = default_hp_search_backend()
|
||||
|
@ -215,7 +215,7 @@ class BestRun(NamedTuple):
|
||||
"""
|
||||
|
||||
run_id: str
|
||||
objective: float
|
||||
objective: Union[float, List[float]]
|
||||
hyperparameters: Dict[str, Any]
|
||||
run_summary: Optional[Any] = None
|
||||
|
||||
|
@ -26,6 +26,7 @@ import tempfile
|
||||
import unittest
|
||||
from itertools import product
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import numpy as np
|
||||
@ -2310,6 +2311,62 @@ class TrainerHyperParameterOptunaIntegrationTest(unittest.TestCase):
|
||||
trainer.hyperparameter_search(direction="minimize", hp_space=hp_space, hp_name=hp_name, n_trials=4)
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_optuna
|
||||
class TrainerHyperParameterMultiObjectOptunaIntegrationTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
args = TrainingArguments("..")
|
||||
self.n_epochs = args.num_train_epochs
|
||||
self.batch_size = args.train_batch_size
|
||||
|
||||
def test_hyperparameter_search(self):
|
||||
class MyTrialShortNamer(TrialShortNamer):
|
||||
DEFAULTS = {"a": 0, "b": 0}
|
||||
|
||||
def hp_space(trial):
|
||||
return {}
|
||||
|
||||
def model_init(trial):
|
||||
if trial is not None:
|
||||
a = trial.suggest_int("a", -4, 4)
|
||||
b = trial.suggest_int("b", -4, 4)
|
||||
else:
|
||||
a = 0
|
||||
b = 0
|
||||
config = RegressionModelConfig(a=a, b=b, double_output=False)
|
||||
|
||||
return RegressionPreTrainedModel(config)
|
||||
|
||||
def hp_name(trial):
|
||||
return MyTrialShortNamer.shortname(trial.params)
|
||||
|
||||
def compute_objective(metrics: Dict[str, float]) -> List[float]:
|
||||
return metrics["eval_loss"], metrics["eval_accuracy"]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
trainer = get_regression_trainer(
|
||||
output_dir=tmp_dir,
|
||||
learning_rate=0.1,
|
||||
logging_steps=1,
|
||||
evaluation_strategy=IntervalStrategy.EPOCH,
|
||||
save_strategy=IntervalStrategy.EPOCH,
|
||||
num_train_epochs=10,
|
||||
disable_tqdm=True,
|
||||
load_best_model_at_end=True,
|
||||
logging_dir="runs",
|
||||
run_name="test",
|
||||
model_init=model_init,
|
||||
compute_metrics=AlmostAccuracy(),
|
||||
)
|
||||
trainer.hyperparameter_search(
|
||||
direction=["minimize", "maximize"],
|
||||
hp_space=hp_space,
|
||||
hp_name=hp_name,
|
||||
n_trials=4,
|
||||
compute_objective=compute_objective,
|
||||
)
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_ray
|
||||
class TrainerHyperParameterRayIntegrationTest(unittest.TestCase):
|
||||
|
Loading…
Reference in New Issue
Block a user