Use object store to pass trainer object to Ray Tune (#9749)
This commit is contained in:
parent 6312fed47d
commit d63ab61525
@@ -149,20 +149,20 @@ def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
 def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
     import ray

-    def _objective(trial, checkpoint_dir=None):
+    def _objective(trial, local_trainer, checkpoint_dir=None):
         model_path = None
         if checkpoint_dir:
             for subdir in os.listdir(checkpoint_dir):
                 if subdir.startswith(PREFIX_CHECKPOINT_DIR):
                     model_path = os.path.join(checkpoint_dir, subdir)
-        trainer.objective = None
-        trainer.train(model_path=model_path, trial=trial)
+        local_trainer.objective = None
+        local_trainer.train(model_path=model_path, trial=trial)
         # If there hasn't been any evaluation during the training loop.
-        if getattr(trainer, "objective", None) is None:
-            metrics = trainer.evaluate()
-            trainer.objective = trainer.compute_objective(metrics)
-            trainer._tune_save_checkpoint()
-            ray.tune.report(objective=trainer.objective, **metrics, done=True)
+        if getattr(local_trainer, "objective", None) is None:
+            metrics = local_trainer.evaluate()
+            local_trainer.objective = local_trainer.compute_objective(metrics)
+            local_trainer._tune_save_checkpoint()
+            ray.tune.report(objective=local_trainer.objective, **metrics, done=True)

     # The model and TensorBoard writer do not pickle so we have to remove them (if they exists)
     # while doing the ray hp search.
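
Note: the trainable no longer closes over the outer `trainer`; it receives it as the `local_trainer` argument, which Ray Tune binds via `tune.with_parameters` in the next hunk. A minimal sketch of that signature pattern, with illustrative names that are not part of this commit:

    from ray import tune

    def sketch_objective(config, local_obj=None, checkpoint_dir=None):
        # `config` comes from the search space; `local_obj` is not captured in
        # the closure but injected by tune.with_parameters, so the function
        # itself stays cheap to serialize when Tune ships it to workers.
        score = config["lr"] * getattr(local_obj, "scale", 1.0)
        tune.report(objective=score, done=True)
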
@@ -217,7 +217,12 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
                 "Trainer `args`.".format(cls=type(kwargs["scheduler"]).__name__)
             )

-    analysis = ray.tune.run(_objective, config=trainer.hp_space(None), num_samples=n_trials, **kwargs)
+    analysis = ray.tune.run(
+        ray.tune.with_parameters(_objective, local_trainer=trainer),
+        config=trainer.hp_space(None),
+        num_samples=n_trials,
+        **kwargs,
+    )
     best_trial = analysis.get_best_trial(metric="objective", mode=direction[:3])
     best_run = BestRun(best_trial.trial_id, best_trial.last_result["objective"], best_trial.config)
     if _tb_writer is not None:
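
Why this matters: `ray.tune.with_parameters` places the passed objects in the Ray object store (roughly one `ray.put` per object) and hands each trial a lightweight reference, rather than pickling the full `Trainer` into the trial function for every sample. A self-contained sketch of the same pattern, using the function-based Tune API of this era and made-up names not taken from the commit:

    import numpy as np
    import ray
    from ray import tune

    def objective(config, payload=None):
        # `payload` is resolved from the object store on the worker, not
        # pickled into this function's closure.
        tune.report(objective=float(config["x"] * payload.sum()), done=True)

    if __name__ == "__main__":
        ray.init()
        payload = np.ones((10_000, 100))  # stand-in for a heavyweight Trainer
        analysis = tune.run(
            tune.with_parameters(objective, payload=payload),
            config={"x": tune.uniform(0.0, 1.0)},
            num_samples=4,
        )
        best = analysis.get_best_trial(metric="objective", mode="max")
        print(best.last_result["objective"])
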