Ray Tune Integration Bug Fixes (#10406)

* fixes

* update resources

* formatting

* remove import

* add log statement

* use fstring

* add period

* Update src/transformers/integrations.py
Amog Kamsetty 2021-02-26 16:06:08 -08:00 committed by GitHub
parent 98569d4ba2
commit a85eb616f7
2 changed files with 30 additions and 22 deletions

src/transformers/integrations.py

@@ -17,7 +17,6 @@ Integrations with other Python libraries.
 import importlib.util
 import io
 import json
-import math
 import numbers
 import os
 import re
@@ -174,16 +173,23 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
     _tb_writer = trainer.pop_callback(TensorBoardCallback)
     trainer.model = None
-    # Setup default `resources_per_trial` and `reporter`.
-    if "resources_per_trial" not in kwargs and trainer.args.n_gpu > 0:
-        # `args.n_gpu` is considered the total number of GPUs that will be split
-        # among the `n_jobs`
-        n_jobs = int(kwargs.pop("n_jobs", 1))
-        num_gpus_per_trial = trainer.args.n_gpu
-        if num_gpus_per_trial / n_jobs >= 1:
-            num_gpus_per_trial = int(math.ceil(num_gpus_per_trial / n_jobs))
-        kwargs["resources_per_trial"] = {"gpu": num_gpus_per_trial}
+    # Setup default `resources_per_trial`.
+    if "resources_per_trial" not in kwargs:
+        # Default to 1 CPU and 1 GPU (if applicable) per trial.
+        kwargs["resources_per_trial"] = {"cpu": 1}
+        if trainer.args.n_gpu > 0:
+            kwargs["resources_per_trial"]["gpu"] = 1
+        resource_msg = "1 CPU" + (" and 1 GPU" if trainer.args.n_gpu > 0 else "")
+        logger.info(
+            "No `resources_per_trial` arg was passed into "
+            "`hyperparameter_search`. Setting it to a default value "
+            f"of {resource_msg} for each trial."
+        )
+    # Make sure each trainer only uses GPUs that were allocated per trial.
+    gpus_per_trial = kwargs["resources_per_trial"].get("gpu", 0)
+    trainer.args._n_gpu = gpus_per_trial
     # Setup default `progress_reporter`.
     if "progress_reporter" not in kwargs:
         from ray.tune import CLIReporter
@@ -193,7 +199,8 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
         trainer.use_tune_checkpoints = True
         if kwargs["keep_checkpoints_num"] > 1:
             logger.warning(
-                "Currently keeping {} checkpoints for each trial. Checkpoints are usually huge, "
+                f"Currently keeping {kwargs['keep_checkpoints_num']} checkpoints for each trial. "
+                "Checkpoints are usually huge, "
                 "consider setting `keep_checkpoints_num=1`."
             )
     if "scheduler" in kwargs:
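For context (not part of this diff), here is a rough sketch of how the new default surfaces to users: extra keyword arguments passed to `Trainer.hyperparameter_search` with the Ray backend are forwarded on to `ray.tune.run`, so `resources_per_trial` can either be omitted (falling back to the 1 CPU / 1 GPU default above) or supplied explicitly. The model name and the tiny dataset below are placeholders, not anything prescribed by the PR.

import torch
from torch.utils.data import Dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

class ToyDataset(Dataset):
    # Tiny synthetic dataset, only here to keep the sketch self-contained.
    def __init__(self, tokenizer, n=32):
        texts = ["a good example", "a bad example"] * (n // 2)
        self.labels = [1, 0] * (n // 2)
        self.enc = tokenizer(texts, truncation=True, padding=True)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, i):
        item = {k: torch.tensor(v[i]) for k, v in self.enc.items()}
        item["labels"] = torch.tensor(self.labels[i])
        return item

model_name = "distilbert-base-uncased"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)

def model_init():
    # hyperparameter_search re-instantiates the model for every trial.
    return AutoModelForSequenceClassification.from_pretrained(model_name)

trainer = Trainer(
    model_init=model_init,
    args=TrainingArguments(output_dir="hp_search", evaluation_strategy="epoch"),
    train_dataset=ToyDataset(tokenizer),
    eval_dataset=ToyDataset(tokenizer),
)

# Omitting `resources_per_trial`: each trial now gets 1 CPU, plus 1 GPU when
# trainer.args.n_gpu > 0, as set up in the hunk above.
best_run = trainer.hyperparameter_search(backend="ray", n_trials=2)

# Passing it explicitly still works; the kwarg is forwarded to ray.tune.run.
best_run = trainer.hyperparameter_search(
    backend="ray", n_trials=2, resources_per_trial={"cpu": 2, "gpu": 1}
)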

src/transformers/trainer.py

@@ -707,7 +707,7 @@ class Trainer:
         elif self.hp_search_backend == HPSearchBackend.RAY:
             from ray import tune
-            if self.state.global_step % self.args.save_steps == 0:
+            if self.control.should_save:
                 self._tune_save_checkpoint()
             tune.report(objective=self.objective, **metrics)
@@ -717,8 +717,7 @@ class Trainer:
         if not self.use_tune_checkpoints:
             return
         with tune.checkpoint_dir(step=self.state.global_step) as checkpoint_dir:
-            self.args.output_dir = checkpoint_dir
-            output_dir = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
+            output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
             self.save_model(output_dir)
             if self.is_world_process_zero():
                 self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))
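As background rather than part of the change, the corrected `_tune_save_checkpoint` writes into the directory that `tune.checkpoint_dir` hands out instead of permanently re-pointing `self.args.output_dir` at it. A minimal standalone sketch of that Ray Tune (1.x-era function API) checkpointing pattern follows, with a dummy objective; newer Ray releases have since replaced this API.

import os
from ray import tune

def trainable(config, checkpoint_dir=None):
    start = 0
    if checkpoint_dir:
        # Resume from the checkpoint Tune restored for this trial.
        with open(os.path.join(checkpoint_dir, "step.txt")) as f:
            start = int(f.read()) + 1
    for step in range(start, 10):
        loss = (10 - step) * config["lr"]  # dummy objective
        # Write the checkpoint *inside* the directory Tune provides; Tune then
        # tracks and rotates it per trial (e.g. via keep_checkpoints_num).
        with tune.checkpoint_dir(step=step) as ckpt_dir:
            with open(os.path.join(ckpt_dir, "step.txt"), "w") as f:
                f.write(str(step))
        tune.report(loss=loss)

analysis = tune.run(trainable, config={"lr": tune.loguniform(1e-4, 1e-1)}, num_samples=2)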
@@ -1201,12 +1200,12 @@ class Trainer:
                 run_id = tune.get_trial_id()
             run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}"
-            output_dir = os.path.join(self.args.output_dir, run_name, checkpoint_folder)
+            run_dir = os.path.join(self.args.output_dir, run_name)
         else:
-            output_dir = os.path.join(self.args.output_dir, checkpoint_folder)
+            run_dir = self.args.output_dir
         self.store_flos()
+        output_dir = os.path.join(run_dir, checkpoint_folder)
         self.save_model(output_dir)
         if self.deepspeed:
             self.deepspeed.save_checkpoint(output_dir)
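To make the path change concrete (illustrative values only, not from the diff): each Ray trial now gets its own run directory, and checkpoint folders are nested underneath it, which is what lets the rotation below operate per trial.

import os

output_dir = "results"                # self.args.output_dir
run_name = "run-a1b2c3d4"             # from tune.get_trial_id() or hp_name(trial)
checkpoint_folder = "checkpoint-500"  # f"{PREFIX_CHECKPOINT_DIR}-{global_step}"

run_dir = os.path.join(output_dir, run_name)         # results/run-a1b2c3d4
ckpt_dir = os.path.join(run_dir, checkpoint_folder)  # results/run-a1b2c3d4/checkpoint-500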
@@ -1250,7 +1249,7 @@ class Trainer:
         # Maybe delete some older checkpoints.
         if self.is_world_process_zero():
-            self._rotate_checkpoints(use_mtime=True)
+            self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
     def _load_optimizer_and_scheduler(self, checkpoint):
         """If optimizer and scheduler states exist, load them."""
@@ -1559,10 +1558,12 @@ class Trainer:
         else:
             self.state.total_flos = self._total_flos
-    def _sorted_checkpoints(self, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False) -> List[str]:
+    def _sorted_checkpoints(
+        self, output_dir=None, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False
+    ) -> List[str]:
         ordering_and_checkpoint_path = []
-        glob_checkpoints = [str(x) for x in Path(self.args.output_dir).glob(f"{checkpoint_prefix}-*")]
+        glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{checkpoint_prefix}-*")]
         for path in glob_checkpoints:
             if use_mtime:
@@ -1583,12 +1584,12 @@ class Trainer:
         )
         return checkpoints_sorted
-    def _rotate_checkpoints(self, use_mtime=False) -> None:
+    def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None:
         if self.args.save_total_limit is None or self.args.save_total_limit <= 0:
             return
         # Check if we should delete older checkpoint(s)
-        checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime)
+        checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime, output_dir=output_dir)
         if len(checkpoints_sorted) <= self.args.save_total_limit:
             return
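A simplified standalone sketch (not the Trainer's exact implementation) of what the `output_dir`-aware sorting and rotation amount to: list the `checkpoint-*` folders under the given run directory, order them oldest first by step number or mtime, and delete everything beyond `save_total_limit`.

import re
import shutil
from pathlib import Path

def rotate_checkpoints(run_dir, prefix="checkpoint", use_mtime=False, save_total_limit=2):
    # Collect (sort_key, path) pairs for every checkpoint folder under run_dir.
    ordered = []
    for path in Path(run_dir).glob(f"{prefix}-*"):
        if use_mtime:
            ordered.append((path.stat().st_mtime, str(path)))
        else:
            match = re.search(rf"{prefix}-(\d+)", str(path))
            if match:
                ordered.append((int(match.group(1)), str(path)))
    ordered.sort()  # oldest checkpoint (smallest step or mtime) first
    # Remove the oldest folders until only `save_total_limit` remain.
    for _, path in ordered[: max(0, len(ordered) - save_total_limit)]:
        shutil.rmtree(path, ignore_errors=True)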