Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Ray Tune Integration Bug Fixes (#10406)
* fixes
* update resources
* formatting
* remove import
* add log statement
* use fstring
* add period
* Update src/transformers/integrations.py
This commit is contained in:
parent 98569d4ba2
commit a85eb616f7
src/transformers/integrations.py

@@ -17,7 +17,6 @@ Integrations with other Python libraries.
 import importlib.util
 import io
 import json
-import math
 import numbers
 import os
 import re
@@ -174,16 +173,23 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:

     _tb_writer = trainer.pop_callback(TensorBoardCallback)
     trainer.model = None
-    # Setup default `resources_per_trial` and `reporter`.
-    if "resources_per_trial" not in kwargs and trainer.args.n_gpu > 0:
-        # `args.n_gpu` is considered the total number of GPUs that will be split
-        # among the `n_jobs`
-        n_jobs = int(kwargs.pop("n_jobs", 1))
-        num_gpus_per_trial = trainer.args.n_gpu
-        if num_gpus_per_trial / n_jobs >= 1:
-            num_gpus_per_trial = int(math.ceil(num_gpus_per_trial / n_jobs))
-        kwargs["resources_per_trial"] = {"gpu": num_gpus_per_trial}
+    # Setup default `resources_per_trial`.
+    if "resources_per_trial" not in kwargs:
+        # Default to 1 CPU and 1 GPU (if applicable) per trial.
+        kwargs["resources_per_trial"] = {"cpu": 1}
+        if trainer.args.n_gpu > 0:
+            kwargs["resources_per_trial"]["gpu"] = 1
+        resource_msg = "1 CPU" + (" and 1 GPU" if trainer.args.n_gpu > 0 else "")
+        logger.info(
+            "No `resources_per_trial` arg was passed into "
+            "`hyperparameter_search`. Setting it to a default value "
+            f"of {resource_msg} for each trial."
+        )
+    # Make sure each trainer only uses GPUs that were allocated per trial.
+    gpus_per_trial = kwargs["resources_per_trial"].get("gpu", 0)
+    trainer.args._n_gpu = gpus_per_trial
+
     # Setup default `progress_reporter`.
     if "progress_reporter" not in kwargs:
         from ray.tune import CLIReporter

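With this change a trial defaults to 1 CPU plus at most 1 GPU, and `trainer.args._n_gpu` is clamped to whatever the trial was actually allocated. Callers who want more resources per trial can still pass `resources_per_trial` themselves; a minimal caller-side sketch (the search space, trial count, and a pre-built `trainer` with `model_init` are illustrative assumptions, not part of this commit):

# Hypothetical usage sketch: override the new 1 CPU / 1 GPU per-trial default.
from ray import tune

best_run = trainer.hyperparameter_search(
    backend="ray",
    n_trials=8,
    direction="minimize",
    # Forwarded through **kwargs to ray.tune.run: 2 CPUs and 1 GPU per trial.
    resources_per_trial={"cpu": 2, "gpu": 1},
    # Illustrative search space; hp_space is a callable returning the config dict.
    hp_space=lambda _: {"learning_rate": tune.loguniform(1e-5, 1e-3)},
)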
@@ -193,7 +199,8 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
         trainer.use_tune_checkpoints = True
         if kwargs["keep_checkpoints_num"] > 1:
             logger.warning(
-                "Currently keeping {} checkpoints for each trial. Checkpoints are usually huge, "
+                f"Currently keeping {kwargs['keep_checkpoint_num']} checkpoints for each trial. "
+                "Checkpoints are usually huge, "
                 "consider setting `keep_checkpoints_num=1`."
             )
     if "scheduler" in kwargs:
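Because the warning fires whenever more than one checkpoint per trial is kept, a caller who wants Tune checkpointing without the disk usage can combine `keep_checkpoints_num=1` with a scheduler. A hedged sketch, assuming the `trainer` is configured to evaluate during training (which ASHA-style schedulers need); the metric and mode values are illustrative:

# Hypothetical usage sketch: per-trial checkpointing capped at one checkpoint.
from ray.tune.schedulers import ASHAScheduler

best_run = trainer.hyperparameter_search(
    backend="ray",
    n_trials=4,
    keep_checkpoints_num=1,  # forwarded to ray.tune.run; also enables Tune checkpointing
    scheduler=ASHAScheduler(metric="objective", mode="min"),
)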
src/transformers/trainer.py

@@ -707,7 +707,7 @@ class Trainer:
         elif self.hp_search_backend == HPSearchBackend.RAY:
             from ray import tune

-            if self.state.global_step % self.args.save_steps == 0:
+            if self.control.should_save:
                 self._tune_save_checkpoint()
             tune.report(objective=self.objective, **metrics)

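Switching from the modulo arithmetic to `self.control.should_save` defers to the flag the Trainer's flow callback already maintains, so Tune checkpoints follow the same save schedule as regular checkpoints. A rough standalone illustration of the decision that flag encodes; the helper and its arguments are hypothetical, not the actual callback code:

# Hypothetical sketch of the save decision delegated to the control flag.
def should_save_now(global_step: int, save_steps: int, epoch_end: bool, save_on_epoch: bool = False) -> bool:
    # Save either at the end of an epoch or every `save_steps` optimization steps.
    if save_on_epoch:
        return epoch_end
    return save_steps > 0 and global_step % save_steps == 0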
@@ -717,8 +717,7 @@ class Trainer:
         if not self.use_tune_checkpoints:
             return
         with tune.checkpoint_dir(step=self.state.global_step) as checkpoint_dir:
-            self.args.output_dir = checkpoint_dir
-            output_dir = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
+            output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
             self.save_model(output_dir)
             if self.is_world_process_zero():
                 self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))
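Writing the checkpoint inside the directory yielded by `tune.checkpoint_dir` (rather than mutating `args.output_dir`) matches the Ray 1.x function-API checkpointing pattern this integration targeted. A self-contained sketch of the same pattern outside the Trainer, with a made-up trainable and metric:

# Minimal Ray Tune function-API sketch (Ray 1.x style) mirroring the
# tune.checkpoint_dir usage above; the trainable is purely illustrative.
import os
from ray import tune

def trainable(config):
    for step in range(10):
        loss = 1.0 / (step + 1)  # fake training signal
        with tune.checkpoint_dir(step=step) as checkpoint_dir:
            with open(os.path.join(checkpoint_dir, "state.txt"), "w") as f:
                f.write(f"step={step}\n")
        tune.report(objective=loss)

analysis = tune.run(trainable, num_samples=2)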
@@ -1201,12 +1200,12 @@ class Trainer:

                 run_id = tune.get_trial_id()
             run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}"
-            output_dir = os.path.join(self.args.output_dir, run_name, checkpoint_folder)
+            run_dir = os.path.join(self.args.output_dir, run_name)
         else:
-            output_dir = os.path.join(self.args.output_dir, checkpoint_folder)
-
+            run_dir = self.args.output_dir
         self.store_flos()

+        output_dir = os.path.join(run_dir, checkpoint_folder)
         self.save_model(output_dir)
         if self.deepspeed:
             self.deepspeed.save_checkpoint(output_dir)
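The net effect of the `run_dir` refactor: during a hyperparameter search each checkpoint lands under a per-trial directory such as `output_dir/run-<trial_id>/checkpoint-<step>`, while ordinary training keeps writing `output_dir/checkpoint-<step>`. A small sketch of that path construction (the helper name is made up):

# Illustrative reconstruction of the run_dir / output_dir layout, not Trainer source.
import os
from typing import Optional

PREFIX_CHECKPOINT_DIR = "checkpoint"

def checkpoint_output_dir(output_dir: str, global_step: int, run_name: Optional[str] = None) -> str:
    checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{global_step}"
    run_dir = os.path.join(output_dir, run_name) if run_name else output_dir
    return os.path.join(run_dir, checkpoint_folder)

print(checkpoint_output_dir("out", 500, "run-abc123"))  # out/run-abc123/checkpoint-500
print(checkpoint_output_dir("out", 500))                # out/checkpoint-500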
@@ -1250,7 +1249,7 @@ class Trainer:

         # Maybe delete some older checkpoints.
         if self.is_world_process_zero():
-            self._rotate_checkpoints(use_mtime=True)
+            self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)

     def _load_optimizer_and_scheduler(self, checkpoint):
         """If optimizer and scheduler states exist, load them."""
@@ -1559,10 +1558,12 @@ class Trainer:
         else:
             self.state.total_flos = self._total_flos

-    def _sorted_checkpoints(self, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False) -> List[str]:
+    def _sorted_checkpoints(
+        self, output_dir=None, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False
+    ) -> List[str]:
         ordering_and_checkpoint_path = []

-        glob_checkpoints = [str(x) for x in Path(self.args.output_dir).glob(f"{checkpoint_prefix}-*")]
+        glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{checkpoint_prefix}-*")]

         for path in glob_checkpoints:
             if use_mtime:
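With `output_dir` now a parameter, `_sorted_checkpoints` scans whichever directory it is given (the per-trial `run_dir` during a search) instead of always scanning `args.output_dir`. A hedged reconstruction of the sorting idea, ordering folders by the step number in their name; not the exact method body:

# Reconstruction of checkpoint sorting for an arbitrary output_dir.
import re
from pathlib import Path
from typing import List

def sorted_checkpoints(output_dir: str, checkpoint_prefix: str = "checkpoint") -> List[str]:
    ordered = []
    for path in Path(output_dir).glob(f"{checkpoint_prefix}-*"):
        match = re.search(rf"{checkpoint_prefix}-(\d+)", path.name)
        if match:
            ordered.append((int(match.group(1)), str(path)))
    return [p for _, p in sorted(ordered)]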
@@ -1583,12 +1584,12 @@ class Trainer:
         )
         return checkpoints_sorted

-    def _rotate_checkpoints(self, use_mtime=False) -> None:
+    def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None:
         if self.args.save_total_limit is None or self.args.save_total_limit <= 0:
             return

         # Check if we should delete older checkpoint(s)
-        checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime)
+        checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime, output_dir=output_dir)
         if len(checkpoints_sorted) <= self.args.save_total_limit:
             return

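Threading `output_dir` into `_rotate_checkpoints` means rotation deletes stale checkpoints from the same per-trial directory they were written to, so one trial's cleanup no longer touches another trial's files. A self-contained sketch of the rotation step under that assumption (a reconstruction, not the library code):

# Reconstruction of checkpoint rotation for a given output_dir: keep the newest
# `save_total_limit` checkpoint-<step> folders and delete the rest.
import re
import shutil
from pathlib import Path
from typing import Optional

def rotate_checkpoints(output_dir: str, save_total_limit: Optional[int]) -> None:
    if save_total_limit is None or save_total_limit <= 0:
        return
    candidates = []
    for path in Path(output_dir).glob("checkpoint-*"):
        match = re.search(r"checkpoint-(\d+)", path.name)
        if match:
            candidates.append((int(match.group(1)), path))
    for _, stale in sorted(candidates)[: max(0, len(candidates) - save_total_limit)]:
        shutil.rmtree(stale)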