Mirror of https://github.com/huggingface/transformers.git, synced 2025-08-01 02:31:11 +06:00
add DDP HPO support for sigopt (#18931)
Only the main process runs the HPO loop; it passes the resulting arguments to the other processes.

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent 9faa9f9dac
commit a86acb75ad
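The coordination pattern of the change, as a minimal sketch (the helper sync_args_from_main below is hypothetical, not part of the Trainer API, and it assumes the default process group is already initialized): the main process pickles the TrainingArguments that SigOpt filled in and broadcasts the bytes element-wise with torch.distributed.broadcast_object_list, while every other rank overwrites its local copy, preserving only its own local_rank.

import pickle

import torch.distributed as dist


def sync_args_from_main(args):
    # Every rank allocates a buffer of the same length. This relies on the
    # pickled arguments serializing to the same number of bytes on every
    # rank, which is the assumption the commit's own code makes as well.
    payload = list(pickle.dumps(args))
    # broadcast_object_list overwrites the buffer in place on non-source ranks.
    dist.broadcast_object_list(payload, src=0)
    if dist.get_rank() != 0:
        local_rank = args.local_rank  # keep this rank's own device binding
        args = pickle.loads(bytes(payload))  # adopt the HPO-chosen arguments
        args.local_rank = local_rank
    return args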
@@ -19,6 +19,7 @@ import importlib.util
 import json
 import numbers
 import os
+import pickle
 import shutil
 import sys
 import tempfile
@@ -28,11 +29,13 @@ from typing import TYPE_CHECKING, Dict, Optional
 import numpy as np

 from . import __version__ as version
-from .utils import flatten_dict, is_datasets_available, logging
+from .utils import flatten_dict, is_datasets_available, is_torch_available, logging


 logger = logging.get_logger(__name__)

+if is_torch_available():
+    import torch

 # comet_ml requires to be imported before any ML frameworks
 _has_comet = importlib.util.find_spec("comet_ml") is not None and os.getenv("COMET_MODE", "").upper() != "DISABLED"
@@ -55,6 +58,7 @@ if TYPE_CHECKING and _has_neptune:

 from .trainer_callback import ProgressCallback, TrainerCallback  # noqa: E402
 from .trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, IntervalStrategy  # noqa: E402
+from .training_args import ParallelMode  # noqa: E402
 from .utils import ENV_VARS_TRUE_VALUES, is_torch_tpu_available  # noqa: E402

@@ -317,67 +321,94 @@ def run_hp_search_sigopt(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
     import sigopt
     from transformers.utils.versions import importlib_metadata

-    if importlib_metadata.version("sigopt") >= "8.0.0":
-        sigopt.set_project("huggingface")
+    if trainer.args.process_index == 0:
+        if importlib_metadata.version("sigopt") >= "8.0.0":
+            sigopt.set_project("huggingface")

-        experiment = sigopt.create_experiment(
-            name="huggingface-tune",
-            type="offline",
-            parameters=trainer.hp_space(None),
-            metrics=[dict(name="objective", objective=direction, strategy="optimize")],
-            parallel_bandwidth=1,
-            budget=n_trials,
-        )
+            experiment = sigopt.create_experiment(
+                name="huggingface-tune",
+                type="offline",
+                parameters=trainer.hp_space(None),
+                metrics=[dict(name="objective", objective=direction, strategy="optimize")],
+                parallel_bandwidth=1,
+                budget=n_trials,
+            )

-        logger.info(f"created experiment: https://app.sigopt.com/experiment/{experiment.id}")
+            logger.info(f"created experiment: https://app.sigopt.com/experiment/{experiment.id}")

-        for run in experiment.loop():
-            with run:
-                trainer.objective = None
-                trainer.train(resume_from_checkpoint=None, trial=run.run)
-                # If there hasn't been any evaluation during the training loop.
-                if getattr(trainer, "objective", None) is None:
-                    metrics = trainer.evaluate()
-                    trainer.objective = trainer.compute_objective(metrics)
-                run.log_metric("objective", trainer.objective)
+            for run in experiment.loop():
+                with run:
+                    trainer.objective = None
+                    trainer._hp_search_setup(run.run)
+                    if trainer.args.world_size > 1:
+                        if trainer.args.parallel_mode != ParallelMode.DISTRIBUTED:
+                            raise RuntimeError("only support DDP Sigopt HPO for ParallelMode.DISTRIBUTED currently.")
+                        torch.distributed.broadcast_object_list(pickle.dumps(trainer.args), src=0)
+                    trainer.train(resume_from_checkpoint=None)
+                    # If there hasn't been any evaluation during the training loop.
+                    if getattr(trainer, "objective", None) is None:
+                        metrics = trainer.evaluate()
+                        trainer.objective = trainer.compute_objective(metrics)
+                    run.log_metric("objective", trainer.objective)

-        best = list(experiment.get_best_runs())[0]
-        best_run = BestRun(best.id, best.values["objective"].value, best.assignments)
-    else:
-        from sigopt import Connection
+            best = list(experiment.get_best_runs())[0]
+            best_run = BestRun(best.id, best.values["objective"].value, best.assignments)
+        else:
+            from sigopt import Connection

-        conn = Connection()
-        proxies = kwargs.pop("proxies", None)
-        if proxies is not None:
-            conn.set_proxies(proxies)
+            conn = Connection()
+            proxies = kwargs.pop("proxies", None)
+            if proxies is not None:
+                conn.set_proxies(proxies)

-        experiment = conn.experiments().create(
-            name="huggingface-tune",
-            parameters=trainer.hp_space(None),
-            metrics=[dict(name="objective", objective=direction, strategy="optimize")],
-            parallel_bandwidth=1,
-            observation_budget=n_trials,
-            project="huggingface",
-        )
-        logger.info(f"created experiment: https://app.sigopt.com/experiment/{experiment.id}")
+            experiment = conn.experiments().create(
+                name="huggingface-tune",
+                parameters=trainer.hp_space(None),
+                metrics=[dict(name="objective", objective=direction, strategy="optimize")],
+                parallel_bandwidth=1,
+                observation_budget=n_trials,
+                project="huggingface",
+            )
+            logger.info(f"created experiment: https://app.sigopt.com/experiment/{experiment.id}")

-        while experiment.progress.observation_count < experiment.observation_budget:
-            suggestion = conn.experiments(experiment.id).suggestions().create()
-            trainer.objective = None
-            trainer.train(resume_from_checkpoint=None, trial=suggestion)
-            # If there hasn't been any evaluation during the training loop.
-            if getattr(trainer, "objective", None) is None:
-                metrics = trainer.evaluate()
-                trainer.objective = trainer.compute_objective(metrics)
+            while experiment.progress.observation_count < experiment.observation_budget:
+                suggestion = conn.experiments(experiment.id).suggestions().create()
+                trainer.objective = None
+                trainer._hp_search_setup(suggestion)
+                if trainer.args.world_size > 1:
+                    if trainer.args.parallel_mode != ParallelMode.DISTRIBUTED:
+                        raise RuntimeError("only support DDP Sigopt HPO for ParallelMode.DISTRIBUTED currently.")
+                    torch.distributed.broadcast_object_list(pickle.dumps(trainer.args), src=0)
+                trainer.train(resume_from_checkpoint=None)
+                # If there hasn't been any evaluation during the training loop.
+                if getattr(trainer, "objective", None) is None:
+                    metrics = trainer.evaluate()
+                    trainer.objective = trainer.compute_objective(metrics)

-            values = [dict(name="objective", value=trainer.objective)]
-            obs = conn.experiments(experiment.id).observations().create(suggestion=suggestion.id, values=values)
-            logger.info(f"[suggestion_id, observation_id]: [{suggestion.id}, {obs.id}]")
-            experiment = conn.experiments(experiment.id).fetch()
+                values = [dict(name="objective", value=trainer.objective)]
+                obs = conn.experiments(experiment.id).observations().create(suggestion=suggestion.id, values=values)
+                logger.info(f"[suggestion_id, observation_id]: [{suggestion.id}, {obs.id}]")
+                experiment = conn.experiments(experiment.id).fetch()

-        best = list(conn.experiments(experiment.id).best_assignments().fetch().iterate_pages())[0]
-        best_run = BestRun(best.id, best.value, best.assignments)
-    return best_run
+            best = list(conn.experiments(experiment.id).best_assignments().fetch().iterate_pages())[0]
+            best_run = BestRun(best.id, best.value, best.assignments)
+        return best_run
+    else:
+        for i in range(n_trials):
+            trainer.objective = None
+            args_main_rank = list(pickle.dumps(trainer.args))
+            if trainer.args.parallel_mode != ParallelMode.DISTRIBUTED:
+                raise RuntimeError("only support DDP Sigopt HPO for ParallelMode.DISTRIBUTED currently.")
+            torch.distributed.broadcast_object_list(args_main_rank, src=0)
+            local_rank = trainer.args.local_rank  # backup the local_rank info
+            trainer.args = pickle.loads(bytes(args_main_rank))
+            trainer.args.local_rank = local_rank
+            trainer.train(resume_from_checkpoint=None)
+            # If there hasn't been any evaluation during the training loop.
+            if getattr(trainer, "objective", None) is None:
+                metrics = trainer.evaluate()
+                trainer.objective = trainer.compute_objective(metrics)
+    return None


 def run_hp_search_wandb(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
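After this change, every rank enters the hyperparameter search together: rank 0 drives the SigOpt loop and returns the BestRun, while the other ranks train with the broadcast arguments and get None back. A rough usage sketch (the script name and trial count are illustrative):

# launch on 2 GPUs, e.g.: torchrun --nproc_per_node=2 run_hpo.py
best_run = trainer.hyperparameter_search(
    backend="sigopt",
    direction="maximize",
    n_trials=10,  # becomes the experiment budget above
)
# best_run is a BestRun on the main process and None on the other ranks.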