mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 18:22:34 +06:00
Only log total_flos at the end of training (#7981)
* Only log total_flos at the end of training * Fix test
This commit is contained in:
parent
ff65beafa3
commit
06fc3954a1
@ -830,6 +830,10 @@ class Trainer:
|
||||
state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME))
|
||||
self.model.load_state_dict(state_dict)
|
||||
|
||||
if self._total_flos is not None:
|
||||
self.store_flos()
|
||||
self.log({"total_flos": self.state.total_flos})
|
||||
|
||||
self.control = self.callback_handler.on_train_end(self.args, self.state, self.control)
|
||||
|
||||
return TrainOutput(self.state.global_step, tr_loss.item() / self.state.global_step)
|
||||
@ -1013,9 +1017,6 @@ class Trainer:
|
||||
return self._log(logs)
|
||||
if self.state.epoch is not None:
|
||||
logs["epoch"] = self.state.epoch
|
||||
if self._total_flos is not None:
|
||||
self.store_flos()
|
||||
logs["total_flos"] = self.state.total_flos
|
||||
|
||||
self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)
|
||||
output = {**logs, **{"step": self.state.global_step}}
|
||||
|
@ -114,12 +114,7 @@ def default_compute_objective(metrics: Dict[str, float]) -> float:
|
||||
metrics = copy.deepcopy(metrics)
|
||||
loss = metrics.pop("eval_loss", None)
|
||||
_ = metrics.pop("epoch", None)
|
||||
_ = metrics.pop("total_flos", None)
|
||||
if len(metrics) != 0:
|
||||
raise RuntimeError(
|
||||
"Metrics contains more entries than just 'eval_loss', 'epoch' and 'total_flos', please provide your own compute_objective function."
|
||||
)
|
||||
return loss
|
||||
return loss if len(metrics) == 0 else sum(metrics.values())
|
||||
|
||||
|
||||
def default_hp_space_optuna(trial) -> Dict[str, float]:
|
||||
|
@ -125,7 +125,7 @@ class TrainerCallbackTest(unittest.TestCase):
|
||||
expected_events.append("on_epoch_end")
|
||||
if trainer.args.evaluation_strategy == EvaluationStrategy.EPOCH:
|
||||
expected_events += evaluation_events.copy()
|
||||
expected_events.append("on_train_end")
|
||||
expected_events += ["on_log", "on_train_end"]
|
||||
return expected_events
|
||||
|
||||
def test_init_callback(self):
|
||||
|
Loading…
Reference in New Issue
Block a user