mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix hyperparameter search when optuna+deepseed (#34642)
* Fix hyperparameter search when optuna+deepseed * Adding free_memory to the search setup --------- Co-authored-by: Corentin-Royer <corentin.royer@ibm.com>
This commit is contained in:
parent
67890de3b8
commit
bf42c3bd4b
@ -208,7 +208,7 @@ def hp_params(trial):
|
||||
if is_optuna_available():
|
||||
import optuna
|
||||
|
||||
if isinstance(trial, optuna.Trial):
|
||||
if isinstance(trial, optuna.trial.BaseTrial):
|
||||
return trial.params
|
||||
if is_ray_tune_available():
|
||||
if isinstance(trial, dict):
|
||||
@ -230,7 +230,7 @@ def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> Be
|
||||
|
||||
if trainer.args.process_index == 0:
|
||||
|
||||
def _objective(trial, checkpoint_dir=None):
|
||||
def _objective(trial: optuna.Trial, checkpoint_dir=None):
|
||||
checkpoint = None
|
||||
if checkpoint_dir:
|
||||
for subdir in os.listdir(checkpoint_dir):
|
||||
@ -240,10 +240,11 @@ def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> Be
|
||||
if trainer.args.world_size > 1:
|
||||
if trainer.args.parallel_mode != ParallelMode.DISTRIBUTED:
|
||||
raise RuntimeError("only support DDP optuna HPO for ParallelMode.DISTRIBUTED currently.")
|
||||
trainer._hp_search_setup(trial)
|
||||
args_main_rank_list = [pickle.dumps(trainer.args)]
|
||||
torch.distributed.broadcast_object_list(args_main_rank_list, src=0)
|
||||
trainer.train(resume_from_checkpoint=checkpoint)
|
||||
trainer.hp_space(trial)
|
||||
fixed_trial = optuna.trial.FixedTrial(trial.params, trial.number)
|
||||
trial_main_rank_list = [fixed_trial]
|
||||
torch.distributed.broadcast_object_list(trial_main_rank_list, src=0)
|
||||
trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
|
||||
else:
|
||||
trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
|
||||
# If there hasn't been any evaluation during the training loop.
|
||||
@ -268,15 +269,11 @@ def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> Be
|
||||
else:
|
||||
for i in range(n_trials):
|
||||
trainer.objective = None
|
||||
args_main_rank_list = [None]
|
||||
trial_main_rank_list = [None]
|
||||
if trainer.args.parallel_mode != ParallelMode.DISTRIBUTED:
|
||||
raise RuntimeError("only support DDP optuna HPO for ParallelMode.DISTRIBUTED currently.")
|
||||
torch.distributed.broadcast_object_list(args_main_rank_list, src=0)
|
||||
args = pickle.loads(bytes(args_main_rank_list[0]))
|
||||
for key, value in asdict(args).items():
|
||||
if key != "local_rank":
|
||||
setattr(trainer.args, key, value)
|
||||
trainer.train(resume_from_checkpoint=None)
|
||||
torch.distributed.broadcast_object_list(trial_main_rank_list, src=0)
|
||||
trainer.train(resume_from_checkpoint=None, trial=trial_main_rank_list[0])
|
||||
# If there hasn't been any evaluation during the training loop.
|
||||
if getattr(trainer, "objective", None) is None:
|
||||
metrics = trainer.evaluate()
|
||||
|
@ -1725,6 +1725,9 @@ class Trainer:
|
||||
if self.is_deepspeed_enabled:
|
||||
if self.args.deepspeed is None:
|
||||
raise ValueError("For sweeps with deepspeed, `args.deepspeed` must be set")
|
||||
|
||||
self.accelerator.free_memory()
|
||||
|
||||
# Rebuild the deepspeed config to reflect the updated training parameters
|
||||
from accelerate.utils import DeepSpeedPlugin
|
||||
|
||||
@ -1748,7 +1751,7 @@ class Trainer:
|
||||
if self.hp_search_backend == HPSearchBackend.OPTUNA:
|
||||
import optuna
|
||||
|
||||
if not trial.study._is_multi_objective():
|
||||
if hasattr(trial, "study") and not trial.study._is_multi_objective():
|
||||
trial.report(self.objective, step)
|
||||
if trial.should_prune():
|
||||
self.callback_handler.on_train_end(self.args, self.state, self.control)
|
||||
|
Loading…
Reference in New Issue
Block a user