From d63ab61525d342c7af1bc3e58adad4f2936ead3a Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Mon, 25 Jan 2021 11:01:55 +0100 Subject: [PATCH] Use object store to pass trainer object to Ray Tune (#9749) --- src/transformers/integrations.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/src/transformers/integrations.py b/src/transformers/integrations.py index af7909fca23..6a49635e6cf 100644 --- a/src/transformers/integrations.py +++ b/src/transformers/integrations.py @@ -149,20 +149,20 @@ def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> Be def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestRun: import ray - def _objective(trial, checkpoint_dir=None): + def _objective(trial, local_trainer, checkpoint_dir=None): model_path = None if checkpoint_dir: for subdir in os.listdir(checkpoint_dir): if subdir.startswith(PREFIX_CHECKPOINT_DIR): model_path = os.path.join(checkpoint_dir, subdir) - trainer.objective = None - trainer.train(model_path=model_path, trial=trial) + local_trainer.objective = None + local_trainer.train(model_path=model_path, trial=trial) # If there hasn't been any evaluation during the training loop. - if getattr(trainer, "objective", None) is None: - metrics = trainer.evaluate() - trainer.objective = trainer.compute_objective(metrics) - trainer._tune_save_checkpoint() - ray.tune.report(objective=trainer.objective, **metrics, done=True) + if getattr(local_trainer, "objective", None) is None: + metrics = local_trainer.evaluate() + local_trainer.objective = local_trainer.compute_objective(metrics) + local_trainer._tune_save_checkpoint() + ray.tune.report(objective=local_trainer.objective, **metrics, done=True) # The model and TensorBoard writer do not pickle so we have to remove them (if they exists) # while doing the ray hp search. @@ -217,7 +217,12 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestR "Trainer `args`.".format(cls=type(kwargs["scheduler"]).__name__) ) - analysis = ray.tune.run(_objective, config=trainer.hp_space(None), num_samples=n_trials, **kwargs) + analysis = ray.tune.run( + ray.tune.with_parameters(_objective, local_trainer=trainer), + config=trainer.hp_space(None), + num_samples=n_trials, + **kwargs, + ) best_trial = analysis.get_best_trial(metric="objective", mode=direction[:3]) best_run = BestRun(best_trial.trial_id, best_trial.last_result["objective"], best_trial.config) if _tb_writer is not None: