Only log total_flos at the end of training (#7981)

* Only log total_flos at the end of training

* Fix test
Sylvain Gugger 2020-10-22 14:26:55 -04:00 committed by GitHub
parent ff65beafa3
commit 06fc3954a1
3 changed files with 6 additions and 10 deletions


@@ -830,6 +830,10 @@ class Trainer:
                 state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME))
                 self.model.load_state_dict(state_dict)
 
+        if self._total_flos is not None:
+            self.store_flos()
+            self.log({"total_flos": self.state.total_flos})
+
         self.control = self.callback_handler.on_train_end(self.args, self.state, self.control)
 
         return TrainOutput(self.state.global_step, tr_loss.item() / self.state.global_step)
@@ -1013,9 +1017,6 @@ class Trainer:
             return self._log(logs)
         if self.state.epoch is not None:
             logs["epoch"] = self.state.epoch
-        if self._total_flos is not None:
-            self.store_flos()
-            logs["total_flos"] = self.state.total_flos
         self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)
         output = {**logs, **{"step": self.state.global_step}}

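Taken together, the two hunks above move the FLO bookkeeping out of Trainer.log(), where it ran at every logging step, into a single call at the end of train(), so "total_flos" is stored and logged exactly once, right before on_train_end fires. Below is a minimal sketch, not part of this diff, of how downstream code could pick up that final value through the TrainerCallback.on_log hook; the class name and the print formatting are illustrative.

from transformers import TrainerCallback


class FinalFlosCallback(TrainerCallback):
    """Illustrative callback: report total_flos from the single log emitted at train end."""

    def on_log(self, args, state, control, logs=None, **kwargs):
        # After this change, "total_flos" shows up only in the final log call
        # made by train(), not in every intermediate log.
        if logs and "total_flos" in logs:
            print(f"Training used roughly {logs['total_flos']:.3e} floating-point operations")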

@@ -114,12 +114,7 @@ def default_compute_objective(metrics: Dict[str, float]) -> float:
     metrics = copy.deepcopy(metrics)
     loss = metrics.pop("eval_loss", None)
     _ = metrics.pop("epoch", None)
-    _ = metrics.pop("total_flos", None)
-    if len(metrics) != 0:
-        raise RuntimeError(
-            "Metrics contains more entries than just 'eval_loss', 'epoch' and 'total_flos', please provide your own compute_objective function."
-        )
-    return loss
+    return loss if len(metrics) == 0 else sum(metrics.values())
 
 
 def default_hp_space_optuna(trial) -> Dict[str, float]:

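With "total_flos" no longer injected into every log, the default hyperparameter-search objective stops special-casing it and stops raising on unexpected metrics: it returns eval_loss when that is the only metric, and otherwise sums whatever metrics remain. A small self-contained sketch of the new behaviour, with made-up metric values:

import copy
from typing import Dict


def default_compute_objective(metrics: Dict[str, float]) -> float:
    # Same logic as the replacement line above: use eval_loss when it is the
    # only metric, otherwise sum the remaining metrics.
    metrics = copy.deepcopy(metrics)
    loss = metrics.pop("eval_loss", None)
    _ = metrics.pop("epoch", None)
    return loss if len(metrics) == 0 else sum(metrics.values())


print(default_compute_objective({"eval_loss": 0.25, "epoch": 3.0}))          # 0.25
print(default_compute_objective({"eval_loss": 0.25, "eval_accuracy": 0.9}))  # 0.9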

@@ -125,7 +125,7 @@ class TrainerCallbackTest(unittest.TestCase):
             expected_events.append("on_epoch_end")
             if trainer.args.evaluation_strategy == EvaluationStrategy.EPOCH:
                 expected_events += evaluation_events.copy()
-        expected_events.append("on_train_end")
+        expected_events += ["on_log", "on_train_end"]
         return expected_events
 
     def test_init_callback(self):
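Because train() now ends with a self.log({"total_flos": ...}) call, the callback handler fires one extra on_log event immediately before on_train_end, which is what the updated expectation encodes. A minimal sketch, assuming only the public TrainerCallback hooks, of the kind of event-recording callback such an ordering test pairs with; the class name is illustrative and this is not the test suite's actual helper.

from transformers import TrainerCallback


class EventRecorderCallback(TrainerCallback):
    """Illustrative callback that records, in order, the callback events it receives."""

    def __init__(self):
        self.events = []

    def on_log(self, args, state, control, **kwargs):
        self.events.append("on_log")

    def on_train_end(self, args, state, control, **kwargs):
        self.events.append("on_train_end")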