Use object store to pass trainer object to Ray Tune (#9749)

This commit is contained in:
Kai Fricke 2021-01-25 11:01:55 +01:00 committed by GitHub
parent 6312fed47d
commit d63ab61525
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -149,20 +149,20 @@ def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> Be
def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
import ray
def _objective(trial, checkpoint_dir=None):
def _objective(trial, local_trainer, checkpoint_dir=None):
model_path = None
if checkpoint_dir:
for subdir in os.listdir(checkpoint_dir):
if subdir.startswith(PREFIX_CHECKPOINT_DIR):
model_path = os.path.join(checkpoint_dir, subdir)
trainer.objective = None
trainer.train(model_path=model_path, trial=trial)
local_trainer.objective = None
local_trainer.train(model_path=model_path, trial=trial)
# If there hasn't been any evaluation during the training loop.
if getattr(trainer, "objective", None) is None:
metrics = trainer.evaluate()
trainer.objective = trainer.compute_objective(metrics)
trainer._tune_save_checkpoint()
ray.tune.report(objective=trainer.objective, **metrics, done=True)
if getattr(local_trainer, "objective", None) is None:
metrics = local_trainer.evaluate()
local_trainer.objective = local_trainer.compute_objective(metrics)
local_trainer._tune_save_checkpoint()
ray.tune.report(objective=local_trainer.objective, **metrics, done=True)
# The model and TensorBoard writer do not pickle so we have to remove them (if they exist)
# while doing the ray hp search.
@ -217,7 +217,12 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestR
"Trainer `args`.".format(cls=type(kwargs["scheduler"]).__name__)
)
analysis = ray.tune.run(_objective, config=trainer.hp_space(None), num_samples=n_trials, **kwargs)
analysis = ray.tune.run(
ray.tune.with_parameters(_objective, local_trainer=trainer),
config=trainer.hp_space(None),
num_samples=n_trials,
**kwargs,
)
best_trial = analysis.get_best_trial(metric="objective", mode=direction[:3])
best_run = BestRun(best_trial.trial_id, best_trial.last_result["objective"], best_trial.config)
if _tb_writer is not None: