Report value for a step instead of epoch. (#18095)

* Report value for a step instead of epoch.

Report an objective function value for a step instead of epoch to optuna.
I made this modification for the following reason:
If "eval_steps" is less than steps per epoch, there maybe warnings like this: "optuna/trial/_trial.py:592: UserWarning: The reported value is ignored because this `step` 0 is already reported.". So "step" are more appropriate than "epoch" here.

* MOD: make style.

Co-authored-by: zhaowei01 <zhaowei01@yuanfudao.com>
This commit is contained in:
wei zhao 2022-07-12 20:18:35 +08:00 committed by GitHub
parent d4ebd4e112
commit f5221c06e4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1138,16 +1138,14 @@ class Trainer:
self.args.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.args.deepspeed)
self.args.hf_deepspeed_config.trainer_config_process(self.args)
def _report_to_hp_search(
self, trial: Union["optuna.Trial", Dict[str, Any]], epoch: int, metrics: Dict[str, float]
):
def _report_to_hp_search(self, trial: Union["optuna.Trial", Dict[str, Any]], step: int, metrics: Dict[str, float]):
if self.hp_search_backend is None or trial is None:
return
self.objective = self.compute_objective(metrics.copy())
if self.hp_search_backend == HPSearchBackend.OPTUNA:
import optuna
trial.report(self.objective, epoch)
trial.report(self.objective, step)
if trial.should_prune():
self.callback_handler.on_train_end(self.args, self.state, self.control)
raise optuna.TrialPruned()
@ -1918,7 +1916,7 @@ class Trainer:
metrics = None
if self.control.should_evaluate:
metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
self._report_to_hp_search(trial, epoch, metrics)
self._report_to_hp_search(trial, self.state.global_step, metrics)
if self.control.should_save:
self._save_checkpoint(model, trial, metrics=metrics)