Add W&B backend for hyperparameter sweep (#14582)
# Add support for W&B hyperparameter sweep

This PR:

* allows using wandb for running hyperparameter search.
* visualizes the runs on the W&B sweeps dashboard.
* supports running sweeps on parallel devices, all reporting to the same central dashboard.

### Usage

**To run a new hyperparameter search:**

```python
trainer.hyperparameter_search(
    backend="wandb",
    project="transformers_sweep",  # name of the project
    n_trials=5,
    metric="eval/loss",  # metric to be optimized, default 'eval/loss'. A warning is raised if the passed metric is not found.
)
```

This outputs a sweep id, e.g. `my_project/sweep_id`.

**To run sweeps on parallel devices:** just pass the sweep id you want to run in parallel:

```python
trainer.hyperparameter_search(
    backend="wandb",
    sweep_id="my_project/sweep_id",
)
```
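For context, the search returns a `BestRun` (see `run_hp_search_wandb` in the diff below), so the winning trial can be read back directly from the return value. A minimal sketch, assuming a `trainer` built with `model_init=...` (required for hyperparameter search) and that wandb is installed and logged in; project name and trial count are illustrative:

```python
# Sketch: run a W&B-backed search and inspect the returned BestRun.
best_run = trainer.hyperparameter_search(
    backend="wandb",
    project="transformers_sweep",  # illustrative project name
    n_trials=5,
    metric="eval/loss",
    direction="minimize",
)

print(best_run.run_id)           # id of the best W&B run
print(best_run.objective)        # best value of the optimized metric
print(best_run.hyperparameters)  # dict of the winning sweep assignments
```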
This commit is contained in:

parent 13297ac71c
commit c74f3d4c48

Changed workflow file: .github/workflows/self-scheduled.yml (vendored), 2 changes.
.github/workflows/self-scheduled.yml

```diff
@@ -38,6 +38,7 @@ jobs:
         pip install .[integrations,sklearn,testing,onnxruntime,sentencepiece,torch-speech,vision,timm]
         pip install https://github.com/kpu/kenlm/archive/master.zip
         python -m pip install 'git+https://github.com/facebookresearch/detectron2.git'
+        wandb login ${{ secrets.WANDB_API_KEY }}
 
     - name: Are GPUs recognized by our DL frameworks
       run: |
@@ -271,6 +272,7 @@ jobs:
         pip install .[integrations,sklearn,testing,onnxruntime,sentencepiece,torch-speech,vision,timm]
         pip install https://github.com/kpu/kenlm/archive/master.zip
         python -m pip install 'git+https://github.com/facebookresearch/detectron2.git'
+        wandb login ${{ secrets.WANDB_API_KEY }}
 
     - name: Are GPUs recognized by our DL frameworks
       run: |
```
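The hunks above add `wandb login ${{ secrets.WANDB_API_KEY }}` to the scheduled CI workflow. Outside of CI, the same authentication can be done from Python before launching a search. A minimal sketch, assuming the API key is exposed through the `WANDB_API_KEY` environment variable (the same secret the workflow uses):

```python
import os

import wandb

# Sketch: authenticate with W&B locally before running a sweep.
# wandb.login also falls back to the WANDB_API_KEY environment variable
# (or a previous `wandb login`), so passing the key explicitly is optional.
wandb.login(key=os.environ.get("WANDB_API_KEY"))
```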
```diff
@@ -125,6 +125,10 @@ def hp_params(trial):
         if isinstance(trial, dict):
             return trial
 
+    if is_wandb_available():
+        if isinstance(trial, dict):
+            return trial
+
     raise RuntimeError(f"Unknown type for trial {trial.__class__}")
```
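With the W&B backend the "trial" handed around is a plain dict of sweep assignments, so `hp_params` simply passes it through. A tiny sketch of that behavior, assuming wandb (or another dict-based backend such as Ray or SigOpt) is installed; the parameter names are illustrative:

```python
from transformers.integrations import hp_params

# Sketch: a W&B trial is already a dict of assignments, so it is returned unchanged.
trial = {"learning_rate": 3e-5, "num_train_epochs": 3}
assert hp_params(trial) == trial
```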
```diff
@@ -337,6 +341,75 @@ def run_hp_search_sigopt(trainer, n_trials: int, direction: str, **kwargs) -> Be
     return best_run
 
 
+def run_hp_search_wandb(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
+    from .integrations import is_wandb_available
+
+    if not is_wandb_available():
+        raise ImportError("This function needs wandb installed: `pip install wandb`")
+    import wandb
+
+    # add WandbCallback if not already added in trainer callbacks
+    reporting_to_wandb = False
+    for callback in trainer.callback_handler.callbacks:
+        if isinstance(callback, WandbCallback):
+            reporting_to_wandb = True
+            break
+    if not reporting_to_wandb:
+        trainer.add_callback(WandbCallback())
+    trainer.args.report_to = "wandb"
+    best_trial = {"run_id": None, "objective": None, "hyperparameters": None}
+    sweep_id = kwargs.pop("sweep_id", None)
+    project = kwargs.pop("project", None)
+    name = kwargs.pop("name", None)
+    entity = kwargs.pop("entity", None)
+    metric = kwargs.pop("metric", "eval/loss")
+
+    sweep_config = trainer.hp_space(None)
+    sweep_config["metric"]["goal"] = direction
+    sweep_config["metric"]["name"] = metric
+    if name:
+        sweep_config["name"] = name
+
+    def _objective():
+
+        run = wandb.run if wandb.run else wandb.init()
+        trainer.state.trial_name = run.name
+        run.config.update({"assignments": {}, "metric": metric})
+        config = wandb.config
+
+        trainer.objective = None
+
+        trainer.train(resume_from_checkpoint=None, trial=vars(config)["_items"])
+        # If there hasn't been any evaluation during the training loop.
+        if getattr(trainer, "objective", None) is None:
+            metrics = trainer.evaluate()
+            trainer.objective = trainer.compute_objective(metrics)
+            format_metrics = rewrite_logs(metrics)
+            if metric not in format_metrics:
+                logger.warning(
+                    f"Provided metric {metric} not found. This might result in unexpected sweeps charts. The available metrics are {format_metrics.keys()}"
+                )
+        best_score = False
+        if best_trial["run_id"] is not None:
+            if direction == "minimize":
+                best_score = trainer.objective < best_trial["objective"]
+            elif direction == "maximize":
+                best_score = trainer.objective > best_trial["objective"]
+
+        if best_score or best_trial["run_id"] is None:
+            best_trial["run_id"] = run.id
+            best_trial["objective"] = trainer.objective
+            best_trial["hyperparameters"] = dict(config)
+
+        return trainer.objective
+
+    sweep_id = wandb.sweep(sweep_config, project=project, entity=entity) if not sweep_id else sweep_id
+    logger.info(f"wandb sweep id - {sweep_id}")
+    wandb.agent(sweep_id, function=_objective, count=n_trials)
+
+    return BestRun(best_trial["run_id"], best_trial["objective"], best_trial["hyperparameters"])
+
+
 def get_available_reporting_integrations():
     integrations = []
     if is_azureml_available():
```
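`run_hp_search_wandb` above (apparently part of the integrations module) follows the standard W&B sweep pattern: create or reuse a sweep, then let an agent call an objective function `n_trials` times. A stripped-down, self-contained sketch of that pattern, independent of the `Trainer`; the project name, parameter ranges, and fake loss are illustrative:

```python
import wandb

# Illustrative sweep configuration; parameter names and ranges are made up.
sweep_config = {
    "method": "random",
    "metric": {"name": "eval/loss", "goal": "minimize"},
    "parameters": {"learning_rate": {"distribution": "uniform", "min": 1e-6, "max": 1e-4}},
}


def objective():
    # Each agent invocation starts (or reuses) a run whose config holds the sampled assignments.
    run = wandb.run if wandb.run else wandb.init()
    lr = wandb.config["learning_rate"]
    fake_loss = lr * 100  # stand-in for a real training/evaluation loop
    run.log({"eval/loss": fake_loss})
    return fake_loss


sweep_id = wandb.sweep(sweep_config, project="sweep_pattern_demo")  # hypothetical project name
wandb.agent(sweep_id, function=objective, count=3)  # run the objective three times
```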
```diff
@@ -542,6 +615,7 @@ class WandbCallback(TrainerCallback):
         if hp_search:
             self._wandb.finish()
             self._initialized = False
+            args.run_name = None
         if not self._initialized:
             self.setup(args, state, model, **kwargs)
```
```diff
@@ -59,7 +59,7 @@ from .file_utils import (
     is_torchaudio_available,
     is_vision_available,
 )
-from .integrations import is_optuna_available, is_ray_available, is_sigopt_available
+from .integrations import is_optuna_available, is_ray_available, is_sigopt_available, is_wandb_available
 
 
 SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
```
```diff
@@ -590,6 +590,19 @@ def require_sigopt(test_case):
     return test_case
 
 
+def require_wandb(test_case):
+    """
+    Decorator marking a test that requires wandb.
+
+    These tests are skipped when wandb isn't installed.
+
+    """
+    if not is_wandb_available():
+        return unittest.skip("test requires wandb")(test_case)
+    else:
+        return test_case
+
+
 def require_soundfile(test_case):
     """
     Decorator marking a test that requires soundfile
```
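The new decorator is applied like the existing `require_*` markers in the testing utilities. A minimal, hypothetical test sketch for illustration only:

```python
import unittest

from transformers.testing_utils import require_wandb


@require_wandb
class WandbSmokeTest(unittest.TestCase):  # hypothetical test class
    def test_wandb_importable(self):
        # The whole class is skipped automatically when wandb is not installed.
        import wandb

        self.assertTrue(hasattr(wandb, "sweep"))
```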
```diff
@@ -43,9 +43,11 @@ from .integrations import ( # isort: split
     is_optuna_available,
     is_ray_tune_available,
     is_sigopt_available,
+    is_wandb_available,
     run_hp_search_optuna,
     run_hp_search_ray,
     run_hp_search_sigopt,
+    run_hp_search_wandb,
 )
 
 import numpy as np
```
```diff
@@ -937,6 +939,8 @@ class Trainer:
             params.pop("wandb", None)
         elif self.hp_search_backend == HPSearchBackend.SIGOPT:
             params = {k: int(v) if isinstance(v, str) else v for k, v in trial.assignments.items()}
+        elif self.hp_search_backend == HPSearchBackend.WANDB:
+            params = trial
 
         for key, value in params.items():
             if not hasattr(self.args, key):
```
```diff
@@ -953,6 +957,8 @@ class Trainer:
             logger.info("Trial:", trial.params)
         if self.hp_search_backend == HPSearchBackend.SIGOPT:
             logger.info(f"SigOpt Assignments: {trial.assignments}")
+        if self.hp_search_backend == HPSearchBackend.WANDB:
+            logger.info(f"W&B Sweep parameters: {trial}")
         if self.args.deepspeed:
             # Rebuild the deepspeed config to reflect the updated training parameters
             from transformers.deepspeed import HfDeepSpeedConfig
```
```diff
@@ -1646,6 +1652,10 @@ class Trainer:
                 run_id = tune.get_trial_id()
             elif self.hp_search_backend == HPSearchBackend.SIGOPT:
                 run_id = trial.id
+            elif self.hp_search_backend == HPSearchBackend.WANDB:
+                import wandb
+
+                run_id = wandb.run.id
             run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}"
             run_dir = os.path.join(self.args.output_dir, run_name)
         else:
```
```diff
@@ -1848,6 +1858,8 @@ class Trainer:
             )
         if backend == HPSearchBackend.SIGOPT and not is_sigopt_available():
             raise RuntimeError("You picked the sigopt backend, but it is not installed. Use `pip install sigopt`.")
+        if backend == HPSearchBackend.WANDB and not is_wandb_available():
+            raise RuntimeError("You picked the wandb backend, but it is not installed. Use `pip install wandb`.")
         self.hp_search_backend = backend
         if self.model_init is None:
             raise RuntimeError(
```
```diff
@@ -1862,6 +1874,7 @@ class Trainer:
             HPSearchBackend.OPTUNA: run_hp_search_optuna,
             HPSearchBackend.RAY: run_hp_search_ray,
             HPSearchBackend.SIGOPT: run_hp_search_sigopt,
+            HPSearchBackend.WANDB: run_hp_search_wandb,
         }
         best_run = backend_dict[backend](self, n_trials, direction, **kwargs)
```
```diff
@@ -210,16 +210,36 @@ def default_hp_space_sigopt(trial):
     ]
 
 
+def default_hp_space_wandb(trial) -> Dict[str, float]:
+    from .integrations import is_wandb_available
+
+    if not is_wandb_available():
+        raise ImportError("This function needs wandb installed: `pip install wandb`")
+
+    return {
+        "method": "random",
+        "metric": {"name": "objective", "goal": "minimize"},
+        "parameters": {
+            "learning_rate": {"distribution": "uniform", "min": 1e-6, "max": 1e-4},
+            "num_train_epochs": {"distribution": "int_uniform", "min": 1, "max": 6},
+            "seed": {"distribution": "int_uniform", "min": 1, "max": 40},
+            "per_device_train_batch_size": {"values": [4, 8, 16, 32, 64]},
+        },
+    }
+
+
 class HPSearchBackend(ExplicitEnum):
     OPTUNA = "optuna"
     RAY = "ray"
     SIGOPT = "sigopt"
+    WANDB = "wandb"
 
 
 default_hp_space = {
     HPSearchBackend.OPTUNA: default_hp_space_optuna,
     HPSearchBackend.RAY: default_hp_space_ray,
     HPSearchBackend.SIGOPT: default_hp_space_sigopt,
+    HPSearchBackend.WANDB: default_hp_space_wandb,
 }
```
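The `default_hp_space_wandb` above is only used when no `hp_space` is passed to `hyperparameter_search`. A hedged sketch of supplying a custom W&B sweep space instead; the parameter names must match `TrainingArguments` fields, and the ranges and sweep method below are illustrative:

```python
# Sketch: override the default W&B sweep space. The callable receives an
# (unused) trial argument and must return a W&B sweep configuration dict.
def my_hp_space(trial):
    return {
        "method": "bayes",  # illustrative; "random" and "grid" are also valid W&B methods
        "metric": {"name": "eval/loss", "goal": "minimize"},
        "parameters": {
            "learning_rate": {"distribution": "uniform", "min": 1e-6, "max": 1e-4},
            "per_device_train_batch_size": {"values": [8, 16, 32]},
        },
    }


best_run = trainer.hyperparameter_search(
    direction="minimize",
    backend="wandb",
    hp_space=my_hp_space,  # replaces default_hp_space_wandb
    n_trials=10,
)
```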
```diff
@@ -60,6 +60,7 @@ from transformers.testing_utils import (
     require_torch_non_multi_gpu,
     require_torch_tf32,
     require_torch_up_to_2_gpus,
+    require_wandb,
     slow,
 )
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
```
```diff
@@ -1810,3 +1811,59 @@ class TrainerOptimizerChoiceTest(unittest.TestCase):
         with patch.dict("sys.modules", {"apex.optimizers": None}):
             with self.assertRaises(ValueError):
                 Trainer.get_optimizer_cls_and_kwargs(args)
+
+
+@require_torch
+@require_wandb
+class TrainerHyperParameterWandbIntegrationTest(unittest.TestCase):
+    def setUp(self):
+        args = TrainingArguments(".")
+        self.n_epochs = args.num_train_epochs
+        self.batch_size = args.train_batch_size
+
+    def test_hyperparameter_search(self):
+        class MyTrialShortNamer(TrialShortNamer):
+            DEFAULTS = {"a": 0, "b": 0}
+
+        def hp_space(trial):
+
+            return {
+                "method": "random",
+                "metric": {},
+                "parameters": {
+                    "a": {"distribution": "uniform", "min": 1e-6, "max": 1e-4},
+                    "b": {"distribution": "int_uniform", "min": 1, "max": 6},
+                },
+            }
+
+        def model_init(config):
+            if config is None:
+                a = 0
+                b = 0
+            else:
+                a = config["a"]
+                b = config["b"]
+            model_config = RegressionModelConfig(a=a, b=b, double_output=False)
+
+            return RegressionPreTrainedModel(model_config)
+
+        def hp_name(params):
+            return MyTrialShortNamer.shortname(params)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            trainer = get_regression_trainer(
+                output_dir=tmp_dir,
+                learning_rate=0.1,
+                logging_steps=1,
+                evaluation_strategy=IntervalStrategy.EPOCH,
+                save_strategy=IntervalStrategy.EPOCH,
+                num_train_epochs=4,
+                disable_tqdm=True,
+                load_best_model_at_end=True,
+                logging_dir="runs",
+                run_name="test",
+                model_init=model_init,
+            )
+            trainer.hyperparameter_search(
+                direction="minimize", hp_space=hp_space, hp_name=hp_name, backend="wandb", n_trials=4, anonymous="must"
+            )
```