Fix hyperparameter search when optuna+deepspeed (#34642)

* Fix hyperparameter search when optuna+deepspeed

* Adding free_memory to the search setup

---------

Co-authored-by: Corentin-Royer <corentin.royer@ibm.com>
parent 67890de3b8
commit bf42c3bd4b
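In multi-worker runs, the old code broadcast a pickled copy of the TrainingArguments from rank 0 and patched each worker's args field by field, so hyperparameters that do not live on TrainingArguments never reached the other ranks, and the per-field copy broke once DeepSpeed rebuilt its config. The fix broadcasts the sampled trial itself: rank 0 samples with a real optuna.Trial, wraps the sampled values in an optuna.trial.FixedTrial, and every rank then calls trainer.train(..., trial=...) through the same setup path. A minimal single-process sketch of that replay pattern (illustrative only, not code from this commit):

    import optuna

    def hp_space(trial):
        # Same shape as a Trainer hp_space function: suggest values, return a dict.
        return {"learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True)}

    study = optuna.create_study(direction="minimize")
    trial = study.ask()        # rank 0: a real Trial backed by the study
    params = hp_space(trial)   # sampling happens here, filling trial.params

    # What rank 0 broadcasts: a FixedTrial that replays the sampled values.
    fixed = optuna.trial.FixedTrial(trial.params, trial.number)
    assert hp_space(fixed) == params  # other ranks re-run hp_space, get the same dict

    study.tell(trial, 0.0)     # close out the trial in this sketch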
@@ -208,7 +208,7 @@ def hp_params(trial):
     if is_optuna_available():
         import optuna
 
-        if isinstance(trial, optuna.Trial):
+        if isinstance(trial, optuna.trial.BaseTrial):
             return trial.params
     if is_ray_tune_available():
         if isinstance(trial, dict):
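The replayed trials on the non-main ranks are FixedTrial objects, which subclass optuna.trial.BaseTrial but not optuna.Trial, so the old check made hp_params return None for them. A quick check (illustrative, assuming optuna is installed):

    import optuna

    fixed = optuna.trial.FixedTrial({"learning_rate": 3e-5}, 0)
    print(isinstance(fixed, optuna.Trial))            # False: old check rejected it
    print(isinstance(fixed, optuna.trial.BaseTrial))  # True: widened check accepts it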
@@ -230,7 +230,7 @@ def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> Be
 
     if trainer.args.process_index == 0:
 
-        def _objective(trial, checkpoint_dir=None):
+        def _objective(trial: optuna.Trial, checkpoint_dir=None):
             checkpoint = None
             if checkpoint_dir:
                 for subdir in os.listdir(checkpoint_dir):
@@ -240,10 +240,11 @@ def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> Be
             if trainer.args.world_size > 1:
                 if trainer.args.parallel_mode != ParallelMode.DISTRIBUTED:
                     raise RuntimeError("only support DDP optuna HPO for ParallelMode.DISTRIBUTED currently.")
-                trainer._hp_search_setup(trial)
-                args_main_rank_list = [pickle.dumps(trainer.args)]
-                torch.distributed.broadcast_object_list(args_main_rank_list, src=0)
-                trainer.train(resume_from_checkpoint=checkpoint)
+                trainer.hp_space(trial)
+                fixed_trial = optuna.trial.FixedTrial(trial.params, trial.number)
+                trial_main_rank_list = [fixed_trial]
+                torch.distributed.broadcast_object_list(trial_main_rank_list, src=0)
+                trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
             else:
                 trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
             # If there hasn't been any evaluation during the training loop.
@@ -268,15 +269,11 @@ def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> Be
     else:
         for i in range(n_trials):
             trainer.objective = None
-            args_main_rank_list = [None]
+            trial_main_rank_list = [None]
             if trainer.args.parallel_mode != ParallelMode.DISTRIBUTED:
                 raise RuntimeError("only support DDP optuna HPO for ParallelMode.DISTRIBUTED currently.")
-            torch.distributed.broadcast_object_list(args_main_rank_list, src=0)
-            args = pickle.loads(bytes(args_main_rank_list[0]))
-            for key, value in asdict(args).items():
-                if key != "local_rank":
-                    setattr(trainer.args, key, value)
-            trainer.train(resume_from_checkpoint=None)
+            torch.distributed.broadcast_object_list(trial_main_rank_list, src=0)
+            trainer.train(resume_from_checkpoint=None, trial=trial_main_rank_list[0])
             # If there hasn't been any evaluation during the training loop.
             if getattr(trainer, "objective", None) is None:
                 metrics = trainer.evaluate()
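Both branches now move the same kind of object: rank 0 fills trial_main_rank_list with a FixedTrial, the other ranks start from [None], and torch.distributed.broadcast_object_list pickles the object on the source rank and unpickles it everywhere else. A self-contained sketch of that primitive, assuming a single-process "gloo" group purely for illustration:

    import os
    import optuna
    import torch.distributed as dist

    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    # Rank 0 fills the list; every other rank would start with [None].
    trial_main_rank_list = [optuna.trial.FixedTrial({"learning_rate": 3e-5}, 0)]
    dist.broadcast_object_list(trial_main_rank_list, src=0)  # pickles and ships the object
    print(trial_main_rank_list[0].params)  # {'learning_rate': 3e-05}

    dist.destroy_process_group()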
@@ -1725,6 +1725,9 @@ class Trainer:
         if self.is_deepspeed_enabled:
             if self.args.deepspeed is None:
                 raise ValueError("For sweeps with deepspeed, `args.deepspeed` must be set")
+
+            self.accelerator.free_memory()
+
             # Rebuild the deepspeed config to reflect the updated training parameters
             from accelerate.utils import DeepSpeedPlugin
 
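accelerator.free_memory() drops Accelerate's references to previously prepared objects and empties the CUDA cache, so each trial can rebuild a fresh DeepSpeed engine instead of accumulating stale state across the sweep. Roughly (an illustrative sketch, not the Trainer code):

    from accelerate import Accelerator

    accelerator = Accelerator()
    # ... a trial prepares model/optimizer, trains, finishes ...
    accelerator.free_memory()  # release prepared-object references, gc, empty CUDA cache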
@@ -1748,7 +1751,7 @@ class Trainer:
         if self.hp_search_backend == HPSearchBackend.OPTUNA:
             import optuna
 
-            if not trial.study._is_multi_objective():
+            if hasattr(trial, "study") and not trial.study._is_multi_objective():
                 trial.report(self.objective, step)
                 if trial.should_prune():
                     self.callback_handler.on_train_end(self.args, self.state, self.control)
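A FixedTrial has no backing study, so on the non-main ranks the report/prune path must be skipped; the hasattr guard does exactly that while leaving rank 0's real Trial untouched. Illustrative check:

    import optuna

    fixed = optuna.trial.FixedTrial({"learning_rate": 3e-5}, 0)
    print(hasattr(fixed, "study"))  # False: non-main ranks skip trial.report()/should_prune()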