add DDP HPO support for sigopt (#18931)

only main_process will have HPO, and pass argument to other process

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
Wang, Yi 2022-09-12 19:37:25 +08:00 committed by GitHub
parent 9faa9f9dac
commit a86acb75ad
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -19,6 +19,7 @@ import importlib.util
import json
import numbers
import os
import pickle
import shutil
import sys
import tempfile
@ -28,11 +29,13 @@ from typing import TYPE_CHECKING, Dict, Optional
import numpy as np
from . import __version__ as version
from .utils import flatten_dict, is_datasets_available, logging
from .utils import flatten_dict, is_datasets_available, is_torch_available, logging
logger = logging.get_logger(__name__)
if is_torch_available():
import torch
# comet_ml requires to be imported before any ML frameworks
_has_comet = importlib.util.find_spec("comet_ml") is not None and os.getenv("COMET_MODE", "").upper() != "DISABLED"
@ -55,6 +58,7 @@ if TYPE_CHECKING and _has_neptune:
from .trainer_callback import ProgressCallback, TrainerCallback # noqa: E402
from .trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, IntervalStrategy # noqa: E402
from .training_args import ParallelMode # noqa: E402
from .utils import ENV_VARS_TRUE_VALUES, is_torch_tpu_available # noqa: E402
@ -317,67 +321,94 @@ def run_hp_search_sigopt(trainer, n_trials: int, direction: str, **kwargs) -> Be
import sigopt
from transformers.utils.versions import importlib_metadata
if importlib_metadata.version("sigopt") >= "8.0.0":
sigopt.set_project("huggingface")
if trainer.args.process_index == 0:
if importlib_metadata.version("sigopt") >= "8.0.0":
sigopt.set_project("huggingface")
experiment = sigopt.create_experiment(
name="huggingface-tune",
type="offline",
parameters=trainer.hp_space(None),
metrics=[dict(name="objective", objective=direction, strategy="optimize")],
parallel_bandwidth=1,
budget=n_trials,
)
experiment = sigopt.create_experiment(
name="huggingface-tune",
type="offline",
parameters=trainer.hp_space(None),
metrics=[dict(name="objective", objective=direction, strategy="optimize")],
parallel_bandwidth=1,
budget=n_trials,
)
logger.info(f"created experiment: https://app.sigopt.com/experiment/{experiment.id}")
logger.info(f"created experiment: https://app.sigopt.com/experiment/{experiment.id}")
for run in experiment.loop():
with run:
for run in experiment.loop():
with run:
trainer.objective = None
trainer._hp_search_setup(run.run)
if trainer.args.world_size > 1:
if trainer.args.parallel_mode != ParallelMode.DISTRIBUTED:
raise RuntimeError("only support DDP Sigopt HPO for ParallelMode.DISTRIBUTED currently.")
torch.distributed.broadcast_object_list(pickle.dumps(trainer.args), src=0)
trainer.train(resume_from_checkpoint=None)
# If there hasn't been any evaluation during the training loop.
if getattr(trainer, "objective", None) is None:
metrics = trainer.evaluate()
trainer.objective = trainer.compute_objective(metrics)
run.log_metric("objective", trainer.objective)
best = list(experiment.get_best_runs())[0]
best_run = BestRun(best.id, best.values["objective"].value, best.assignments)
else:
from sigopt import Connection
conn = Connection()
proxies = kwargs.pop("proxies", None)
if proxies is not None:
conn.set_proxies(proxies)
experiment = conn.experiments().create(
name="huggingface-tune",
parameters=trainer.hp_space(None),
metrics=[dict(name="objective", objective=direction, strategy="optimize")],
parallel_bandwidth=1,
observation_budget=n_trials,
project="huggingface",
)
logger.info(f"created experiment: https://app.sigopt.com/experiment/{experiment.id}")
while experiment.progress.observation_count < experiment.observation_budget:
suggestion = conn.experiments(experiment.id).suggestions().create()
trainer.objective = None
trainer.train(resume_from_checkpoint=None, trial=run.run)
trainer._hp_search_setup(suggestion)
if trainer.args.world_size > 1:
if trainer.args.parallel_mode != ParallelMode.DISTRIBUTED:
raise RuntimeError("only support DDP Sigopt HPO for ParallelMode.DISTRIBUTED currently.")
torch.distributed.broadcast_object_list(pickle.dumps(trainer.args), src=0)
trainer.train(resume_from_checkpoint=None)
# If there hasn't been any evaluation during the training loop.
if getattr(trainer, "objective", None) is None:
metrics = trainer.evaluate()
trainer.objective = trainer.compute_objective(metrics)
run.log_metric("objective", trainer.objective)
best = list(experiment.get_best_runs())[0]
best_run = BestRun(best.id, best.values["objective"].value, best.assignments)
values = [dict(name="objective", value=trainer.objective)]
obs = conn.experiments(experiment.id).observations().create(suggestion=suggestion.id, values=values)
logger.info(f"[suggestion_id, observation_id]: [{suggestion.id}, {obs.id}]")
experiment = conn.experiments(experiment.id).fetch()
best = list(conn.experiments(experiment.id).best_assignments().fetch().iterate_pages())[0]
best_run = BestRun(best.id, best.value, best.assignments)
return best_run
else:
from sigopt import Connection
conn = Connection()
proxies = kwargs.pop("proxies", None)
if proxies is not None:
conn.set_proxies(proxies)
experiment = conn.experiments().create(
name="huggingface-tune",
parameters=trainer.hp_space(None),
metrics=[dict(name="objective", objective=direction, strategy="optimize")],
parallel_bandwidth=1,
observation_budget=n_trials,
project="huggingface",
)
logger.info(f"created experiment: https://app.sigopt.com/experiment/{experiment.id}")
while experiment.progress.observation_count < experiment.observation_budget:
suggestion = conn.experiments(experiment.id).suggestions().create()
for i in range(n_trials):
trainer.objective = None
trainer.train(resume_from_checkpoint=None, trial=suggestion)
args_main_rank = list(pickle.dumps(trainer.args))
if trainer.args.parallel_mode != ParallelMode.DISTRIBUTED:
raise RuntimeError("only support DDP Sigopt HPO for ParallelMode.DISTRIBUTED currently.")
torch.distributed.broadcast_object_list(args_main_rank, src=0)
local_rank = trainer.args.local_rank # backup the local_rank info
trainer.args = pickle.loads(bytes(args_main_rank))
trainer.args.local_rank = local_rank
trainer.train(resume_from_checkpoint=None)
# If there hasn't been any evaluation during the training loop.
if getattr(trainer, "objective", None) is None:
metrics = trainer.evaluate()
trainer.objective = trainer.compute_objective(metrics)
values = [dict(name="objective", value=trainer.objective)]
obs = conn.experiments(experiment.id).observations().create(suggestion=suggestion.id, values=values)
logger.info(f"[suggestion_id, observation_id]: [{suggestion.id}, {obs.id}]")
experiment = conn.experiments(experiment.id).fetch()
best = list(conn.experiments(experiment.id).best_assignments().fetch().iterate_pages())[0]
best_run = BestRun(best.id, best.value, best.assignments)
return best_run
return None
def run_hp_search_wandb(trainer, n_trials: int, direction: str, **kwargs) -> BestRun: